{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "90c65bc4",
   "metadata": {},
   "source": [
    "# 1.检测数据集处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35577fdd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset layout: each image filename is '<patientID>_<...>', so the patient\n",
    "# ID is the segment before the first underscore.\n",
    "import os\n",
    "yolo_ct_train = '/nfs3-p2/zsxm/dataset/aorta_yolo_ct/images/train'\n",
    "yolo_ct_val = '/nfs3-p2/zsxm/dataset/aorta_yolo_ct/images/val'\n",
    "yolo_ct_test = '/nfs3-p2/zsxm/dataset/aorta_yolo_ct/images/test'\n",
    "yolo_cta_train = '/nfs3-p2/zsxm/dataset/aorta_yolo_cta/images/train'\n",
    "yolo_cta_val = '/nfs3-p2/zsxm/dataset/aorta_yolo_cta/images/val'\n",
    "yolo_cta_test = '/nfs3-p2/zsxm/dataset/aorta_yolo_cta/images/test'\n",
    "\n",
    "def list_patient(root):\n",
    "    \"\"\"Print the number of distinct patient IDs found in directory `root`.\n",
    "\n",
    "    The patient ID is the filename part before the first '_'.\n",
    "    \"\"\"\n",
    "    ps = set()\n",
    "    for img in os.listdir(root):\n",
    "        ps.add(img.split('_')[0])\n",
    "    print(root, len(ps))\n",
    "    #print(sorted(list(ps)))\n",
    "    print('*********************************************************************************************************************************')\n",
    "    \n",
    "# Count distinct patients in every split of both YOLO detection datasets.\n",
    "list_patient(yolo_ct_train)\n",
    "list_patient(yolo_ct_val)\n",
    "list_patient(yolo_ct_test)\n",
    "list_patient(yolo_cta_train)\n",
    "list_patient(yolo_cta_val)\n",
    "list_patient(yolo_cta_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a383e3af",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map each raw-dataset directory to the patient folders it contains, so every\n",
    "# patient in the YOLO splits can be traced back to its source batch.\n",
    "dataset = {'/nfs3-p2/zsxm/dataset/2021-07-23-10-negative':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-07-23-4':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-07-30':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-09-08':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-09-13':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-09-17-negative':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-09-19':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-09-28':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-09-29-negative':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-10-19-aa':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-10-19-imh':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-10-19-pau':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-11-20':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-11-20-imh':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-11-20-pau':list(),    \n",
    "          }\n",
    "for k, v in dataset.items():\n",
    "    v.extend(sorted(os.listdir(k)))\n",
    "\n",
    "def list_patient(root):\n",
    "    \"\"\"Report, for each patient ID under `root`, which raw-dataset batch it came from.\n",
    "\n",
    "    NOTE: shadows the `list_patient` defined in the previous cell. Patients not\n",
    "    found in any batch are listed at the end under 'not find'.\n",
    "    \"\"\"\n",
    "    print(f'\\n\\n********************************{root}*******************************')\n",
    "    ps = set()\n",
    "    for img in os.listdir(root):\n",
    "        ps.add(img.split('_')[0])\n",
    "    ps = sorted(list(ps))\n",
    "    flag = [False] * len(ps)\n",
    "    for k, v in dataset.items():\n",
    "        ls = []\n",
    "        for i, patient in enumerate(ps):\n",
    "            if patient in v:\n",
    "                ls.append(patient)\n",
    "                flag[i] = True\n",
    "        if len(ls) > 0:\n",
    "            print('--------------------------')\n",
    "            print(k, ls)\n",
    "    ls = []\n",
    "    for i, f in enumerate(flag):\n",
    "        if not f:\n",
    "            ls.append(ps[i])\n",
    "    if len(ls) > 0:\n",
    "        print('----------------------------------------------------')\n",
    "        print('not find', ls)\n",
    "    \n",
    "list_patient(yolo_ct_train)\n",
    "list_patient(yolo_ct_val)\n",
    "list_patient(yolo_ct_test)\n",
    "list_patient(yolo_cta_train)\n",
    "list_patient(yolo_cta_val)\n",
    "list_patient(yolo_cta_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cf1cc34",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from tqdm import tqdm\n",
    "from PIL import Image\n",
    "\n",
    "def check_height_equal_width(root):\n",
    "    \"\"\"Check that the first slice of each patient's series '1' and '2' is square.\n",
    "\n",
    "    Prefers windowed slices under 'images_-100_500' and falls back to 'images'\n",
    "    when that folder is absent. Prints non-square images and missing series.\n",
    "    \"\"\"\n",
    "    for patient in tqdm(sorted(os.listdir(root)), desc=root):\n",
    "        if os.path.isfile(os.path.join(root, patient)):\n",
    "            continue  # skip stray top-level files\n",
    "        # Series '1' and '2' were handled by two copy-pasted blocks before;\n",
    "        # the logic is identical, so loop over the series names instead.\n",
    "        for series in ('1', '2'):\n",
    "            series_dir = os.path.join(root, patient, series)\n",
    "            if not os.path.exists(series_dir):\n",
    "                print(os.path.join(root, patient), series + ' not exist')\n",
    "                continue\n",
    "            img_path = os.path.join(series_dir, 'images_-100_500')\n",
    "            try:\n",
    "                img = Image.open(os.path.join(img_path, os.listdir(img_path)[0]))\n",
    "            except FileNotFoundError:\n",
    "                img_path = os.path.join(series_dir, 'images')\n",
    "                img = Image.open(os.path.join(img_path, os.listdir(img_path)[0]))\n",
    "            if img.height != img.width:\n",
    "                print(os.path.join(root, patient), series, img.height, img.width)\n",
    "\n",
    "dataset = {'/nfs3-p2/zsxm/dataset/2021-07-23-10-negative':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-07-23-4':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-07-30':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-09-08':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-09-13':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-09-17-negative':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-09-19':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-09-28':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-09-29-negative':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-10-19-aa':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-10-19-imh':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-10-19-pau':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-11-20':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-11-20-imh':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-11-20-pau':list(),    \n",
    "          }\n",
    "\n",
    "for root in dataset:\n",
    "    check_height_equal_width(root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91858812",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from tqdm import tqdm\n",
    "from PIL import Image\n",
    "\n",
    "def check_height_equal_width(root):\n",
    "    \"\"\"Print every image directly under `root` whose height and width differ.\"\"\"\n",
    "    entries = sorted(os.listdir(root))\n",
    "    for fname in tqdm(entries):\n",
    "        image = Image.open(os.path.join(root, fname))\n",
    "        if image.width != image.height:\n",
    "            print(root, fname, image.height, image.width)\n",
    "\n",
    "yolo_ct_train = '/nfs3-p2/zsxm/dataset/aorta_yolo_ct/images/train'\n",
    "yolo_ct_val = '/nfs3-p2/zsxm/dataset/aorta_yolo_ct/images/val'\n",
    "yolo_cta_train = '/nfs3-p2/zsxm/dataset/aorta_yolo_cta/images/train'\n",
    "yolo_cta_val = '/nfs3-p2/zsxm/dataset/aorta_yolo_cta/images/val'\n",
    "\n",
    "check_height_equal_width(yolo_ct_train)\n",
    "check_height_equal_width(yolo_ct_val)\n",
    "check_height_equal_width(yolo_cta_train)\n",
    "check_height_equal_width(yolo_cta_val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5aba535d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flatten the per-patient YOLO folders added on 2022-08-08 into the aggregate\n",
    "# images/<split> and labels/<split> directories of the CTA detection dataset.\n",
    "import os\n",
    "import shutil\n",
    "from tqdm import tqdm\n",
    "\n",
    "in_root = '/nfs3-p1/zsxm/dataset/aorta_cta_img_label/0yolo_paper_add_22.8.8'\n",
    "out_root = '/nfs3-p1/zsxm/dataset/aorta_yolo_cta'\n",
    "\n",
    "images_path = os.path.join(out_root, \"images\")\n",
    "labels_path = os.path.join(out_root, \"labels\")\n",
    "\n",
    "for cate in ['train', 'val', 'test']:\n",
    "    cate_imgs = os.path.join(images_path, cate)\n",
    "    cate_lbs = os.path.join(labels_path, cate)\n",
    "    os.makedirs(cate_imgs, exist_ok=True)\n",
    "    os.makedirs(cate_lbs, exist_ok=True)\n",
    "    cate_root = os.path.join(in_root, cate)\n",
    "    for patient in tqdm(sorted(os.listdir(cate_root))):\n",
    "        patient_images = os.path.join(cate_root, patient, 'images')\n",
    "        patient_labels = os.path.join(cate_root, patient, 'labels')\n",
    "        # shutil.copy with a directory destination keeps the original filename\n",
    "        for img in os.listdir(patient_images):\n",
    "            shutil.copy(os.path.join(patient_images, img), cate_imgs)\n",
    "            #print(f'{os.path.join(patient_images, img)} 2 {cate_imgs}')\n",
    "        for lb in os.listdir(patient_labels):\n",
    "            shutil.copy(os.path.join(patient_labels, lb), cate_lbs)\n",
    "            #print(f'{os.path.join(patient_labels, lb)} 2 {cate_lbs}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22d28182",
   "metadata": {},
   "source": [
    "yolo主动脉检测的平扫和增强数据集分别在2022.8.8增加了新的测试集和几个加入到原来训练集和验证集的样本"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3eda4d8",
   "metadata": {},
   "source": [
    "# 2.分类数据集处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "346a26b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "ct_train = '/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/train'\n",
    "ct_val = '/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/val'\n",
    "cta_train = '/nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500/train'\n",
    "cta_val = '/nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500/val'\n",
    "\n",
    "def list_patient(root):\n",
    "    \"\"\"Return the set of distinct patient IDs across class folders '0'/'1'/'2' of `root`.\n",
    "\n",
    "    Also prints the count. NOTE: shadows the `list_patient` versions defined in\n",
    "    earlier cells (different directory layout: per-class subfolders).\n",
    "    \"\"\"\n",
    "    ps = set()\n",
    "    for cate in ['0', '1', '2']:\n",
    "        for img in os.listdir(os.path.join(root, cate)):\n",
    "            ps.add(img.split('_')[0])\n",
    "    print(root, len(ps))\n",
    "#     print(sorted(list(ps)))\n",
    "#     print('*********************************************************************************************************************************')\n",
    "    return ps\n",
    "    \n",
    "cttr = list_patient(ct_train)\n",
    "ctva = list_patient(ct_val)\n",
    "ctatr = list_patient(cta_train)\n",
    "ctava = list_patient(cta_val)\n",
    "# Do the CT and CTA classification splits cover the same patients?\n",
    "# Show the differences in both directions.\n",
    "print(cttr == ctatr, ctva == ctava)\n",
    "\n",
    "print(cttr-ctatr)\n",
    "print(ctatr-cttr)\n",
    "print(ctva-ctava)\n",
    "print(ctava-ctva)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48f2196b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Center-cropped variants of the classification datasets; each split should\n",
    "# contain exactly the same patients as the full-image version checked above.\n",
    "center_ct_train = '/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/center/train'\n",
    "center_ct_val = '/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/center/val'\n",
    "center_cta_train = '/nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500/center/train'\n",
    "center_cta_val = '/nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500/center/val'\n",
    "\n",
    "ccttr = list_patient(center_ct_train)\n",
    "cctva = list_patient(center_ct_val)\n",
    "cctatr = list_patient(center_cta_train)\n",
    "cctava = list_patient(center_cta_val)\n",
    "print(cttr == ccttr, ctva == cctva, ctatr == cctatr, ctava == cctava)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51e50480",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select a held-out test set from the training set.\n",
    "import random\n",
    "random.seed(7987)\n",
    "n = 95\n",
    "n0, n1, n2 = 46, 38, 11\n",
    "assert n0+n1+n2 == n\n",
    "\n",
    "def list_patient_cate(root):\n",
    "    \"\"\"Return three sets of patient IDs, one per class folder '0'/'1'/'2' of `root`.\"\"\"\n",
    "    ps = [set(), set(), set()]\n",
    "    for i, cate in enumerate(['0', '1', '2']):\n",
    "        for img in os.listdir(os.path.join(root, cate)):\n",
    "            ps[i].add(img.split('_')[0])\n",
    "    return ps\n",
    "\n",
    "cttrain = list_patient_cate(ct_train)\n",
    "\n",
    "for ps in cttrain:\n",
    "    print(len(ps))\n",
    "\n",
    "# random.sample() on a set is deprecated since Python 3.9 and raises TypeError\n",
    "# from 3.11; sampling from the sorted list also makes the draw reproducible for\n",
    "# a fixed seed (set iteration order is not stable across interpreter runs).\n",
    "cttest = [random.sample(sorted(cttrain[0]), n0), random.sample(sorted(cttrain[1]), n1), random.sample(sorted(cttrain[2]), n2)]\n",
    "ctte = []\n",
    "for te in cttest:\n",
    "    ctte.extend(te)\n",
    "ctte = set(ctte)\n",
    "print(sum([len(te) for te in cttest]), len(ctte), ctte)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45dac9ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check that every CT-drawn test patient also exists in the CTA training set;\n",
    "# any names printed below are missing from the corresponding CTA class.\n",
    "ctatrain = list_patient_cate(cta_train)\n",
    "print(set(cttest[0])-ctatrain[0], '\\n')\n",
    "print(set(cttest[1])-ctatrain[1], '\\n')\n",
    "print(set(cttest[2])-ctatrain[2], '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f649f0a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db1b1c76",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Move the selected test patients out of train/ into test/.\n",
    "import shutil\n",
    "\n",
    "def move_test(root):\n",
    "    \"\"\"Move all images of the patients drawn into `cttest` from `root`/train/<cate>\n",
    "    to `root`/test/<cate>, printing per-class file counts before and after.\n",
    "    \"\"\"\n",
    "    test_dir = os.path.join(root, 'test')\n",
    "    os.makedirs(test_dir, exist_ok=True)\n",
    "    for i, cate in enumerate(['0', '1', '2']):\n",
    "        orig_cate_dir = os.path.join(root, 'train', cate)\n",
    "        print(root, cate, len(os.listdir(orig_cate_dir)))\n",
    "        test_cate_dir = os.path.join(test_dir, cate)\n",
    "        os.makedirs(test_cate_dir, exist_ok=True)\n",
    "        for img in os.listdir(orig_cate_dir):\n",
    "            if img.split('_')[0] in cttest[i]:\n",
    "                # (dead 'pass' left over from dry-running this loop removed)\n",
    "                shutil.move(os.path.join(orig_cate_dir, img), os.path.join(test_cate_dir, img))\n",
    "        print(root, cate, len(os.listdir(orig_cate_dir)), len(os.listdir(test_cate_dir)), '\\n')\n",
    "                \n",
    "move_test('/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500')\n",
    "move_test('/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/center')\n",
    "move_test('/nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500')\n",
    "move_test('/nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500/center')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "634a4bd9",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# List the distinct-patient count of each newly created test split.\n",
    "list_patient('/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/test')\n",
    "list_patient('/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/center/test')\n",
    "list_patient('/nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500/test')\n",
    "list_patient('/nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500/center/test')\n",
    "print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55af5305",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move a mistakenly created test split back into the training set.\n",
    "def move_back(root):\n",
    "    \"\"\"Undo `move_test`: return every file under `root`/test/<cate> to\n",
    "    `root`/train/<cate>, printing per-class file counts before and after.\n",
    "\n",
    "    Utility for manual recovery; not called in the cells visible here.\n",
    "    \"\"\"\n",
    "    test_dir = os.path.join(root, 'test')\n",
    "    for i, cate in enumerate(['0', '1', '2']):\n",
    "        orig_cate_dir = os.path.join(root, 'train', cate)\n",
    "        test_cate_dir = os.path.join(test_dir, cate)\n",
    "        print(root, cate, len(os.listdir(orig_cate_dir)), len(os.listdir(test_cate_dir)))\n",
    "        for img in os.listdir(test_cate_dir):\n",
    "            shutil.move(os.path.join(test_cate_dir, img), os.path.join(orig_cate_dir, img))\n",
    "        print(root, cate, len(os.listdir(orig_cate_dir)), '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa0fa2b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def count_patient_num(root):\n",
    "    \"\"\"Print distinct patient counts per class under `root`, formatted 'c0/c1/c2/=total'.\"\"\"\n",
    "    per_class = []\n",
    "    for cate in ('0', '1', '2'):\n",
    "        ids = {img.split('_')[0] for img in os.listdir(os.path.join(root, cate))}\n",
    "        per_class.append(ids)\n",
    "    print(root, end=': ')\n",
    "    total = 0\n",
    "    for ids in per_class:\n",
    "        print(len(ids), end='/')\n",
    "        total += len(ids)\n",
    "    print(f'={total}','\\n')\n",
    "\n",
    "count_patient_num('/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/train')\n",
    "count_patient_num('/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/val')\n",
    "count_patient_num('/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/test')\n",
    "\n",
    "count_patient_num('/nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500/train')\n",
    "count_patient_num('/nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500/val')\n",
    "count_patient_num('/nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500/test')\n",
    "\n",
    "count_patient_num('/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/center/train')\n",
    "count_patient_num('/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/center/val')\n",
    "count_patient_num('/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/center/test')\n",
    "\n",
    "count_patient_num('/nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500/center/train')\n",
    "count_patient_num('/nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500/center/val')\n",
    "count_patient_num('/nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500/center/test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0914cb3e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "377f0431",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "ct_train = '/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/train'\n",
    "ct_val = '/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/val'\n",
    "ct_test = '/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/test'\n",
    "cta_train = '/nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500/train'\n",
    "cta_val = '/nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500/val'\n",
    "cta_test = '/nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500/test'\n",
    "\n",
    "def list_patient(root):\n",
    "    \"\"\"Return the set of distinct patient IDs across class folders '0'/'1'/'2' of `root`.\n",
    "\n",
    "    Also prints the count. NOTE: yet another shadowing redefinition of `list_patient`.\n",
    "    \"\"\"\n",
    "    ps = set()\n",
    "    for cate in ['0', '1', '2']:\n",
    "        for img in os.listdir(os.path.join(root, cate)):\n",
    "            ps.add(img.split('_')[0])\n",
    "    print(root, len(ps))\n",
    "    return ps\n",
    "    \n",
    "cttr = list_patient(ct_train)\n",
    "ctva = list_patient(ct_val)\n",
    "ctte = list_patient(ct_test)\n",
    "ctatr = list_patient(cta_train)\n",
    "ctava = list_patient(cta_val)\n",
    "ctate = list_patient(cta_test)\n",
    "\n",
    "print(cttr-ctatr)\n",
    "print(ctatr-cttr)\n",
    "print(ctva-ctava)\n",
    "print(ctava-ctva)\n",
    "print(ctte-ctate)\n",
    "print(ctate-ctte)\n",
    "# So the cases missing from CTA relative to CT are all negatives from the\n",
    "# 09-17 batch, whose annotations are incomplete.\n",
    "'''\n",
    "/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/train 690\n",
    "/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/val 188\n",
    "/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/test 95\n",
    "/nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500/train 637\n",
    "/nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500/val 178\n",
    "/nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500/test 90\n",
    "{'xiaofuxin-25-33-64-153', 'zhanmin-18-23-37-88', 'zhanghonghu-37-45-69-149', 'wanghaifang-40-47-83-160', 'zhouxinya-24-30-40-120', 'zhangsuyun-18-21-37-81', 'weifengli-26-34-62-147', 'wuhanjiang-18-22-39-85', 'yangzengrong-37-48-78-188', 'wutiandi-21-28-54-120', 'zhangmeiqin-19-22-39-86', 'zhouzhangyun-26-35-61-136', 'zhuxiaoqin-22-29-50-130', 'zhaosuifang-23-34-61-154', 'xurihua-26-34-61-145', 'yuanxiazhen-14-18-36-76', 'zhaoyijun-17-20-33-82', 'wumingshi2021061701-25-28-45-98', 'zhuxinnan-18-26-49-129', 'zhudamin-33-41-70-156', 'zhoushijun-30-37-64-145', 'zhouweiqiang-19-24-42-90', 'yuanxicai-35-43-85-169', 'yubaisong-25-33-65-152', 'zhangguofeng-25-30-45-97', 'xuyuetian-22-30-53-132', 'zhangguanghong-39-46-68-156', 'zhengxuefei-21-24-43-86', 'xiyongmin-25-35-62-162', 'yuyunfu-26-33-66-148', 'zhangyitian-26-32-55-148', 'wangdahua-19-24-33-89', 'zengshengjiang-19-23-36-89', 'yinggenhua-36-45-72-174', 'wuyanxian-17-21-36-83', 'zhouyongjin-17-21-35-87', 'xiabao-33-42-78-196', 'wuqinmei-24-32-52-140', 'weihuiqiong-27-33-58-137', 'yangshujuan-30-39-62-160', 'yesenfei-22-32-53-172', 'zhengyouduo-23-27-41-83', 'zhengfahong-22-26-44-99', 'zhumeizheng-34-42-69-167', 'zhoubo-19-26-41-137', 'wangronglin-16-21-28-81', 'zhengjuying-16-19-33-86', 'wuruqing-49-60-96-153', 'yuzhixue-16-21-38-91', 'wumingshi2021080102-38-47-85-156', 'caiqishu-34-43-70-176', 'yangjinnan-31-38-67-146', 'panxiangyang-45-53-97-196'}\n",
    "set()\n",
    "{'zoudeling-11-15-30-84', 'zhangjunfei-25-33-52-125', 'zhouxiushan-14-18-34-84', 'zhaomin-11-15-32-86', 'zhangkesheng-27-34-63-144', 'wenshuijuan-30-36-66-145', 'yangtangyong-15-19-35-85', 'xiasuifu-12-18-30-84', 'zhangjingjing-16-19-30-78', 'zhoushan-17-20-39-78'}\n",
    "set()\n",
    "{'wuzongtang-29-36-76-141', 'zhouhaibin-30-39-74-185', 'zhangjingzhong-53-59-85-174', 'yuyinglong-24-30-51-98', 'zhaorong-33-42-71-163'}\n",
    "set()\n",
    "'''"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e46c4c5",
   "metadata": {},
   "source": [
    "# 3.查看检测和分类数据集的重合"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7c0d116",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "def list_patient_de(root):\n",
    "    \"\"\"Collect distinct patient IDs across the train/val/test splits of a detection dataset.\"\"\"\n",
    "    ps = set()\n",
    "    for dataset in ['train', 'val', 'test']:\n",
    "        dataset_path = os.path.join(root, dataset)\n",
    "        for img in os.listdir(dataset_path):\n",
    "            ps.add(img.split('_')[0])\n",
    "    return ps\n",
    "\n",
    "def list_patient_cl(root):\n",
    "    \"\"\"Collect distinct patient IDs across splits and class folders of a classification dataset.\"\"\"\n",
    "    ps = set()\n",
    "    for dataset in ['train', 'val', 'test']:\n",
    "        dataset_path = os.path.join(root, dataset)\n",
    "        for cate in ['0','1','2']:\n",
    "            cate_path = os.path.join(dataset_path, cate)\n",
    "            for img in os.listdir(cate_path):\n",
    "                ps.add(img.split('_')[0])\n",
    "    return ps\n",
    "\n",
    "# NOTE(review): the classification paths here are under /nfs3-p2 while earlier\n",
    "# cells used /nfs3-p1 for the same dataset names -- confirm both locations hold the same data.\n",
    "de_ct = list_patient_de('/nfs3-p2/zsxm/dataset/aorta_yolo_ct/images')\n",
    "de_cta = list_patient_de('/nfs3-p2/zsxm/dataset/aorta_yolo_cta/images')\n",
    "cl_ct = list_patient_cl('/nfs3-p2/zsxm/dataset/aorta_classify_ct_-100_500')\n",
    "cl_cta = list_patient_cl('/nfs3-p2/zsxm/dataset/aorta_classify_cta_-100_500')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f610ce1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distinct patient counts: detection CT, detection CTA, classification CT, classification CTA.\n",
    "print(len(de_ct), len(de_cta), len(cl_ct), len(cl_cta))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c041cf91",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pairwise overlap between detection and classification patient sets,\n",
    "# reported as (detection-only, classification-only) counts for each combination.\n",
    "print(len(de_ct-cl_ct), len(cl_ct-de_ct))\n",
    "print(len(de_cta-cl_ct), len(cl_ct-de_cta))\n",
    "print(len(de_ct-cl_cta), len(cl_cta-de_ct))\n",
    "print(len(de_cta-cl_cta), len(cl_cta-de_cta))\n",
    "# The classification dataset shares 7 patients with the CT detection set and 6 with the CTA detection set"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54fd69f5",
   "metadata": {},
   "source": [
    "# 4.统计数据集信息到Excel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1aeaba5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pydicom\n",
    "import openpyxl\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "71771808",
   "metadata": {
    "collapsed": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset.file_meta -------------------------------\n",
      "(0002, 0000) File Meta Information Group Length  UL: 198\n",
      "(0002, 0001) File Meta Information Version       OB: b'\\x00\\x01'\n",
      "(0002, 0002) Media Storage SOP Class UID         UI: CT Image Storage\n",
      "(0002, 0003) Media Storage SOP Instance UID      UI: 1.2.840.113619.2.404.3.313266434.891.1612325248.768.1\n",
      "(0002, 0010) Transfer Syntax UID                 UI: JPEG Lossless, Non-Hierarchical, First-Order Prediction (Process 14 [Selection Value 1])\n",
      "(0002, 0012) Implementation Class UID            UI: 9.520.2.3.1\n",
      "(0002, 0013) Implementation Version Name         SH: 'DICOMLIB231'\n",
      "(0002, 0016) Source Application Entity Title     AE: 'eRadPACS02'\n",
      "-------------------------------------------------\n",
      "(0008, 0005) Specific Character Set              CS: 'ISO_IR 100'\n",
      "(0008, 0008) Image Type                          CS: ['ORIGINAL', 'PRIMARY', 'AXIAL']\n",
      "(0008, 0012) Instance Creation Date              DA: '20210212'\n",
      "(0008, 0013) Instance Creation Time              TM: '162030'\n",
      "(0008, 0016) SOP Class UID                       UI: CT Image Storage\n",
      "(0008, 0018) SOP Instance UID                    UI: 1.2.840.113619.2.404.3.313266434.891.1612325248.768.1\n",
      "(0008, 0020) Study Date                          DA: '20210212'\n",
      "(0008, 0021) Series Date                         DA: '20210212'\n",
      "(0008, 0022) Acquisition Date                    DA: '20210212'\n",
      "(0008, 0023) Content Date                        DA: '20210212'\n",
      "(0008, 0030) Study Time                          TM: '161611'\n",
      "(0008, 0031) Series Time                         TM: '161801'\n",
      "(0008, 0032) Acquisition Time                    TM: '161903.034214'\n",
      "(0008, 0033) Content Time                        TM: '162030'\n",
      "(0008, 0050) Accession Number                    SH: 'CT0169187'\n",
      "(0008, 0060) Modality                            CS: 'CT'\n",
      "(0008, 0070) Manufacturer                        LO: 'GE MEDICAL SYSTEMS'\n",
      "(0008, 0080) Institution Name                    LO: 'YiWu Central Hosp'\n",
      "(0008, 0090) Referring Physician's Name          PN: ''\n",
      "(0008, 1010) Station Name                        SH: 'ct99'\n",
      "(0008, 1030) Study Description                   LO: 'PROC DESC'\n",
      "(0008, 103e) Series Description                  LO: 'Recon 2: 5mm std c+'\n",
      "(0008, 1050) Performing Physician's Name         PN: ''\n",
      "(0008, 1060) Name of Physician(s) Reading Study  PN: ''\n",
      "(0008, 1070) Operators' Name                     PN: ''\n",
      "(0008, 1090) Manufacturer's Model Name           LO: 'Optima CT680 Series'\n",
      "(0008, 1110)  Referenced Study Sequence  1 item(s) ---- \n",
      "   (0008, 1150) Referenced SOP Class UID            UI: Detached Study Management SOP Class\n",
      "   (0008, 1155) Referenced SOP Instance UID         UI: 2.2.940.473.8013.20210212.1161459.1040.173889.17388\n",
      "   ---------\n",
      "(0008, 1140)  Referenced Image Sequence  1 item(s) ---- \n",
      "   (0008, 1150) Referenced SOP Class UID            UI: CT Image Storage\n",
      "   (0008, 1155) Referenced SOP Instance UID         UI: 1.2.840.113619.2.404.3.313266434.891.1612325248.197.1\n",
      "   ---------\n",
      "(0008, 3010) Irradiation Event UID               UI: 1.2.840.113619.2.404.3.313266434.891.1612325248.331\n",
      "(0009, 0010) Private Creator                     LO: 'GEMS_IDEN_01'\n",
      "(0009, 1001) [Full fidelity]                     LO: 'CT_LIGHTSPEED'\n",
      "(0009, 1002) [Suite id]                          SH: 'CT99'\n",
      "(0009, 1004) [Product id]                        SH: 'Optima CT680 Ser'\n",
      "(0009, 1027) [Image actual date]                 SL: 1613146681\n",
      "(0009, 10e3) [Equipment UID]                     UI: ''\n",
      "(0010, 0010) Patient's Name                      PN: 'Zhang Guo Qiang'\n",
      "(0010, 0020) Patient ID                          LO: '2372954'\n",
      "(0010, 0030) Patient's Birth Date                DA: '19761107'\n",
      "(0010, 0040) Patient's Sex                       CS: 'M'\n",
      "(0010, 1000) Other Patient IDs                   LO: ''\n",
      "(0010, 1001) Other Patient Names                 PN: 'ÕÅ¹úÇ¿'\n",
      "(0010, 1010) Patient's Age                       AS: '044Y'\n",
      "(0010, 21b0) Additional Patient History          LT: ''\n",
      "(0018, 0022) Scan Options                        CS: 'HELICAL MODE'\n",
      "(0018, 0050) Slice Thickness                     DS: '1.25'\n",
      "(0018, 0060) KVP                                 DS: '120.0'\n",
      "(0018, 0088) Spacing Between Slices              DS: '5.0'\n",
      "(0018, 0090) Data Collection Diameter            DS: '500.0'\n",
      "(0018, 1000) Device Serial Number                LO: '*'\n",
      "(0018, 1020) Software Versions                   LO: 'kl64P1_5.2'\n",
      "(0018, 1030) Protocol Name                       LO: '6.12 Abdomen Pelvis 0.8sec (chang xi mo CTA/CTV)'\n",
      "(0018, 1100) Reconstruction Diameter             DS: '360.0'\n",
      "(0018, 1110) Distance Source to Detector         DS: '949.147'\n",
      "(0018, 1111) Distance Source to Patient          DS: '541.0'\n",
      "(0018, 1120) Gantry/Detector Tilt                DS: '0.0'\n",
      "(0018, 1130) Table Height                        DS: '141.0'\n",
      "(0018, 1140) Rotation Direction                  CS: 'CW'\n",
      "(0018, 1150) Exposure Time                       IS: '800'\n",
      "(0018, 1151) X-Ray Tube Current                  IS: '316'\n",
      "(0018, 1152) Exposure                            IS: '5'\n",
      "(0018, 1160) Filter Type                         SH: 'BODY FILTER'\n",
      "(0018, 1170) Generator Power                     IS: '60000'\n",
      "(0018, 1190) Focal Spot(s)                       DS: '1.2'\n",
      "(0018, 1210) Convolution Kernel                  SH: 'STANDARD'\n",
      "(0018, 5100) Patient Position                    CS: 'HFS'\n",
      "(0018, 9305) Revolution Time                     FD: 0.8\n",
      "(0018, 9306) Single Collimation Width            FD: 0.625\n",
      "(0018, 9307) Total Collimation Width             FD: 40.0\n",
      "(0018, 9309) Table Speed                         FD: 68.75\n",
      "(0018, 9310) Table Feed per Rotation             FD: 55.0\n",
      "(0018, 9311) Spiral Pitch Factor                 FD: 1.375\n",
      "(0019, 0010) Private Creator                     LO: 'GEMS_ACQU_01'\n",
      "(0019, 1002) [Detector Channel]                  SL: 848\n",
      "(0019, 1003) [Cell number at Theta]              DS: '389.75'\n",
      "(0019, 1004) [Cell spacing]                      DS: '1.0947'\n",
      "(0019, 100f) [Horiz. Frame of ref.]              DS: '777.5'\n",
      "(0019, 1011) [Series contrast]                   SS: 0\n",
      "(0019, 1018) [First scan ras]                    LO: 'I'\n",
      "(0019, 101a) [Last scan ras]                     LO: 'I'\n",
      "(0019, 1023) [Table Speed [mm/rotation]]         DS: '55.0'\n",
      "(0019, 1024) [Mid Scan Time [sec]]               DS: '113.420845'\n",
      "(0019, 1025) [Mid scan flag]                     SS: 1\n",
      "(0019, 1026) [Tube Azimuth [degree]]             SL: -1\n",
      "(0019, 1027) [Rotation Speed [msec]]             DS: '0.8'\n",
      "(0019, 102c) [Number of triggers]                SL: 12603\n",
      "(0019, 102e) [Angle of first view]               DS: '0.0'\n",
      "(0019, 102f) [Trigger frequency]                 DS: '1230.0'\n",
      "(0019, 1039) [SFOV Type]                         SS: 1024\n",
      "(0019, 1042) [Segment Number]                    SS: 0\n",
      "(0019, 1043) [Total Segments Required]           SS: 0\n",
      "(0019, 1047) [View compression factor]           SS: 1\n",
      "(0019, 1052) [Recon post proc. Flag]             SS: 1\n",
      "(0019, 106a) [Dependent on #views processed]     SS: 3\n",
      "(0020, 000d) Study Instance UID                  UI: 2.2.940.473.8013.20210212.1161459.1040.173889.17388\n",
      "(0020, 000e) Series Instance UID                 UI: 1.2.840.113619.2.404.3.313266434.891.1612325248.330.3149825\n",
      "(0020, 0010) Study ID                            SH: '13252'\n",
      "(0020, 0011) Series Number                       IS: '5'\n",
      "(0020, 0012) Acquisition Number                  IS: '1'\n",
      "(0020, 0013) Instance Number                     IS: '1'\n",
      "(0020, 0032) Image Position (Patient)            DS: [-192.300, -180.000, -80.250]\n",
      "(0020, 0037) Image Orientation (Patient)         DS: [1.000000, 0.000000, 0.000000, 0.000000, 1.000000, 0.000000]\n",
      "(0020, 0052) Frame of Reference UID              UI: 1.2.840.113619.2.404.3.313266434.891.1612325248.194.7775.1\n",
      "(0020, 1040) Position Reference Indicator        LO: 'XY'\n",
      "(0020, 1041) Slice Location                      DS: '-80.25'\n",
      "(0021, 0010) Private Creator                     LO: 'GEMS_RELA_01'\n",
      "(0021, 1003) [Series from which Prescribed]      SS: 3\n",
      "(0021, 1035) [Series from which prescribed]      SS: 1\n",
      "(0021, 1036) [Image from which prescribed]       SS: 1\n",
      "(0021, 1091) [Biopsy position]                   SS: 0\n",
      "(0021, 1092) [Biopsy T location]                 FL: 0.0\n",
      "(0021, 1093) [Biopsy ref location]               FL: 0.0\n",
      "(0023, 0010) Private Creator                     LO: 'GEMS_STDY_01'\n",
      "(0023, 1070) [Start time(secs) in first axial]   FD: 1613146630.076781\n",
      "(0027, 0010) Private Creator                     LO: 'GEMS_IMAG_01'\n",
      "(0027, 1010) [Scout Type]                        SS: 0\n",
      "(0027, 101c) [Vma mamp]                          SL: 0\n",
      "(0027, 101e) [Vma mod]                           SL: 0\n",
      "(0027, 101f) [Vma clip]                          SL: 98\n",
      "(0027, 1020) [Smart scan ON/OFF flag]            SS: 2\n",
      "(0027, 1035) [Plane Type]                        SS: 2\n",
      "(0027, 1042) [Center R coord of plane image]     FL: 12.300000190734863\n",
      "(0027, 1043) [Center A coord of plane image]     FL: 0.0\n",
      "(0027, 1044) [Center S coord of plane image]     FL: -80.25\n",
      "(0027, 1045) [Normal R coord]                    FL: 0.0\n",
      "(0027, 1046) [Normal A coord]                    FL: -0.0\n",
      "(0027, 1047) [Normal S coord]                    FL: 1.0\n",
      "(0027, 1050) [Scan Start Location]               FL: 0.0\n",
      "(0027, 1051) [Scan End Location]                 FL: 0.0\n",
      "(0028, 0002) Samples per Pixel                   US: 1\n",
      "(0028, 0004) Photometric Interpretation          CS: 'MONOCHROME2'\n",
      "(0028, 0010) Rows                                US: 512\n",
      "(0028, 0011) Columns                             US: 512\n",
      "(0028, 0030) Pixel Spacing                       DS: [0.703125, 0.703125]\n",
      "(0028, 0100) Bits Allocated                      US: 16\n",
      "(0028, 0101) Bits Stored                         US: 16\n",
      "(0028, 0102) High Bit                            US: 15\n",
      "(0028, 0103) Pixel Representation                US: 1\n",
      "(0028, 0120) Pixel Padding Value                 SS: -2000\n",
      "(0028, 1050) Window Center                       DS: '39.0'\n",
      "(0028, 1051) Window Width                        DS: '399.0'\n",
      "(0028, 1052) Rescale Intercept                   DS: '-1024.0'\n",
      "(0028, 1053) Rescale Slope                       DS: '1.0'\n",
      "(0028, 1054) Rescale Type                        LO: 'HU'\n",
      "(0040, 0244) Performed Procedure Step Start Date DA: '20210212'\n",
      "(0040, 0245) Performed Procedure Step Start Time TM: '161611'\n",
      "(0040, 0253) Performed Procedure Step ID         SH: 'PPS ID  13252'\n",
      "(0040, 0254) Performed Procedure Step Descriptio LO: 'PROC DESC'\n",
      "(0040, 0275)  Request Attributes Sequence  1 item(s) ---- \n",
      "   (0008, 0050) Accession Number                    SH: 'CT0169187'\n",
      "   (0008, 1110)  Referenced Study Sequence  1 item(s) ---- \n",
      "      (0008, 1150) Referenced SOP Class UID            UI: Detached Study Management SOP Class\n",
      "      (0008, 1155) Referenced SOP Instance UID         UI: 2.2.940.473.8013.20210212.1161459.1040.173889.17388\n",
      "      ---------\n",
      "   (0020, 000d) Study Instance UID                  UI: 2.2.940.473.8013.20210212.1161459.1040.173889.17388\n",
      "   (0032, 1060) Requested Procedure Description     LO: 'PROC DESC'\n",
      "   (0040, 0007) Scheduled Procedure Step Descriptio LO: 'STEP DESC'\n",
      "   (0040, 0009) Scheduled Procedure Step ID         SH: 'STEP ID'\n",
      "   (0040, 1001) Requested Procedure ID              SH: 'PROC ID'\n",
      "   ---------\n",
      "(0043, 0010) Private Creator                     LO: 'GEMS_PARM_01'\n",
      "(0043, 1010) [Window value]                      US: 400\n",
      "(0043, 1012) [X-ray chain]                       SS: [99, 99, 99]\n",
      "(0043, 1016) [Number of overranges]              SS: -1\n",
      "(0043, 101e) [Delta Start Time [msec]]           DS: '0.063415'\n",
      "(0043, 101f) [Max overranges in a view]          SL: 0\n",
      "(0043, 1021) [Corrected after glow terms]        SS: 0\n",
      "(0043, 1025) [Reference channels]                SS: [0, 0, 0, 0, 0, 0]\n",
      "(0043, 1026) [No views ref chans blocked]        US: [0, 0, 0, 0]\n",
      "(0043, 1027) [Scan Pitch Ratio]                  SH: '1.375:1'\n",
      "(0043, 1028) [Unique image iden]                 OB: b'00'\n",
      "(0043, 102b) [Private Scan Options]              SS: [2, 0, 0, 0]\n",
      "(0043, 1031) [Recon Center Coordinates]          DS: [12.300000, 0.000000]\n",
      "(0043, 1040) [Trigger on position]               FL: 359.2698059082031\n",
      "(0043, 1041) [Degree of rotation]                FL: 4610.853515625\n",
      "(0043, 1042) [DAS trigger source]                SL: 0\n",
      "(0043, 1043) [DAS fpa gain]                      SL: 0\n",
      "(0043, 1044) [DAS output source]                 SL: 0\n",
      "(0043, 1045) [DAS ad input]                      SL: 0\n",
      "(0043, 1046) [DAS cal mode]                      SL: 0\n",
      "(0043, 104d) [Start scan to X-ray on delay]      FL: 0.0\n",
      "(0043, 104e) [Duration of X-ray on]              FL: 10.246341705322266\n",
      "(0043, 1064) [Image Filter]                      CS: ''\n",
      "(0045, 0010) Private Creator                     LO: 'GEMS_HELIOS_01'\n",
      "(0045, 1001) [Number of Macro Rows in Detector]  SS: 64\n",
      "(0045, 1002) [Macro width at ISO Center]         FL: 0.625\n",
      "(0045, 1003) [DAS type]                          SS: 27\n",
      "(0045, 1004) [DAS gain]                          SS: 4\n",
      "(0045, 1006) [Table Direction]                   CS: 'INTO GANTRY'\n",
      "(0045, 1007) [Z smoothing Factor]                FL: 0.0\n",
      "(0045, 1008) [View Weighting Mode]               SS: 0\n",
      "(0045, 1009) [Sigma Row number]                  SS: 0\n",
      "(0045, 100a) [Minimum DAS value]                 FL: 0.0\n",
      "(0045, 100b) [Maximum Offset Value]              FL: 0.0\n",
      "(0045, 100c) [Number of Views shifted]           SS: 0\n",
      "(0045, 100d) [Z tracking Flag]                   SS: 0\n",
      "(0045, 100e) [Mean Z error]                      FL: 0.0\n",
      "(0045, 100f) [Z tracking Error]                  FL: 0.0\n",
      "(0045, 1010) [Start View 2A]                     SS: 0\n",
      "(0045, 1011) [Number of Views 2A]                SS: 0\n",
      "(0045, 1012) [Start View 1A]                     SS: 0\n",
      "(0045, 1013) [Sigma Mode]                        SS: 0\n",
      "(0045, 1014) [Number of Views 1A]                SS: 0\n",
      "(0045, 1015) [Start View 2B]                     SS: 0\n",
      "(0045, 1016) [Number Views 2B]                   SS: 0\n",
      "(0045, 1017) [Start View 1B]                     SS: 0\n",
      "(0045, 1018) [Number of Views 1B]                SS: 0\n",
      "(0045, 1021) [Iterbone Flag]                     SS: 0\n",
      "(0045, 1022) [Perisstaltic Flag]                 SS: 0\n",
      "(0045, 1032) [TemporalResolution]                FL: 0.800000011920929\n",
      "(0045, 103b) [NoiseReductionImageFilterDesc]     LO: ''\n",
      "(0053, 0010) Private Creator                     LO: 'GEHC_CT_ADVAPP_001'\n",
      "(0053, 1020) [ShuttleFlag]                       IS: '0'\n",
      "(0053, 1060) [reconFlipRotateAnno]               SH: ''\n",
      "(0053, 1061) [highResolutionFlag]                SH: '0'\n",
      "(0053, 1062) [RespiratoryFlag]                   SH: '0'\n",
      "(0053, 1064) Private tag data                    IS: '1'\n",
      "(0053, 1065) Private tag data                    IS: '37'\n",
      "(0053, 1066) Private tag data                    LO: '120kV'\n",
      "(0053, 1067) Private tag data                    IS: '0'\n",
      "(0053, 1068) Private tag data                    IS: '0'\n",
      "(0053, 106a) Private tag data                    IS: '0'\n",
      "(0053, 106b) Private tag data                    IS: '0'\n",
      "(0053, 106f) Private tag data                    IS: '0'\n",
      "(0053, 109d) Private tag data                    LO: ''\n",
      "(7fe0, 0010) Pixel Data                          OB: Array of 198546 elements\n"
     ]
    }
   ],
   "source": [
    "ct = pydicom.dcmread('/nfs3-p1/zsxm/dataset/2021-07-23-4/CTA1(+)/2/exported0000.dcm')\n",
    "print(ct)\n",
    "# x = ct.PixelSpacing\n",
    "# print(x, type(x))\n",
    "# list(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4d8e8852",
   "metadata": {},
   "outputs": [],
   "source": [
    "xlsx_path = '/nfs3-p1/zsxm/dataset/paper_classification.xlsx'\n",
    "wb = openpyxl.Workbook()\n",
    "wb.create_sheet(title='total')\n",
    "title_list = ['patient', 'folder', 'PatientFolderName', 'PatientID', 'PatientName', 'Sex', 'Age', 'AccessionNumber']\n",
    "title_list.extend([ 'Date', 'Manufacturer', 'Institution', 'SpacingBetweenSlices', 'SliceThickness', 'PixelSpacing'])\n",
    "wb['total'].append(title_list)\n",
    "\n",
    "dataset = {'/nfs3-p2/zsxm/dataset/2021-07-23-10-negative':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-07-23-4':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-07-30':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-09-08':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-09-13':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-09-17-negative':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-09-19':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-09-28':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-09-29-negative':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-10-19-aa':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-10-19-imh':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-10-19-pau':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-11-20':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-11-20-imh':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-11-20-pau':list(),    \n",
    "          }\n",
    "dataset2 = {}\n",
    "for k, v in dataset.items():\n",
    "    v.extend(sorted(list(filter(lambda x: os.path.isdir(os.path.join(k, x)), os.listdir(k)))))\n",
    "    dataset2[k] = list(map(lambda x: x.split('-')[0], v))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c02f0d3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def statistics_dataset(root, sheet, wb):\n",
    "    def func(x):\n",
    "        try:\n",
    "            int(x)\n",
    "            return True\n",
    "        except ValueError:\n",
    "            return False\n",
    "    def find_dcm(x):\n",
    "        return x.endswith('.dcm') or '.' not in x\n",
    "    \n",
    "    try:\n",
    "        sht = wb[sheet]\n",
    "    except KeyError:\n",
    "        wb.create_sheet(title=sheet)\n",
    "        sht = wb[sheet]\n",
    "        title_list = ['patient', 'folder', 'PatientFolderName', 'PatientID', 'PatientName', 'Sex', 'Age', 'AccessionNumber']\n",
    "        title_list.extend([ 'Date', 'Manufacturer', 'Institution', 'SpacingBetweenSlices', 'SliceThickness', 'PixelSpacing'])\n",
    "        sht.append(title_list)\n",
    "    shttotal = wb['total']\n",
    "    \n",
    "    ps = set()\n",
    "    files = sorted(os.listdir(root))\n",
    "    if len(files) == 3 and files[0] == '0' and files[1] == '1' and files[2] == '2':\n",
    "        for cate in ['0', '1']:#files: #只记录阴性和夹层，忽略壁内血肿\n",
    "            cate_path = os.path.join(root, cate)\n",
    "            for img in os.listdir(cate_path):\n",
    "                ps.add(img.split('_')[0])\n",
    "    else:\n",
    "        for img in files:\n",
    "            ps.add(img.split('_')[0])\n",
    "    ps = sorted(list(ps))\n",
    "    \n",
    "    try:\n",
    "        for patient in tqdm(ps, desc=root, ncols=100):\n",
    "            folder_list = []\n",
    "            for k, v in dataset.items():\n",
    "                if patient in v:\n",
    "                    folder_list.append(k)\n",
    "            assert len(folder_list) < 2, f'{patient}:{folder_list}'\n",
    "            if len(folder_list) == 0:\n",
    "                patient_literal_name = patient.split('-')[0]\n",
    "                for k, v in dataset2.items():\n",
    "                    if patient_literal_name in v:\n",
    "                        folder_list.append(k)\n",
    "                assert len(folder_list) < 2, f'{patient}:{folder_list}'\n",
    "                folder = folder_list[0]\n",
    "                patient_folder_name_list = [n for n in os.listdir(folder) if n.startswith(patient_literal_name)]\n",
    "                assert len(patient_folder_name_list) == 1, f'{patient}:{folder}:{patient_folder_name_list}'\n",
    "                patient_folder_name = patient_folder_name_list[0]\n",
    "            else:\n",
    "                folder = folder_list[0]\n",
    "                patient_folder_name = patient\n",
    "            patient_info = [patient, os.path.basename(folder), patient_folder_name]\n",
    "            if 'cta' in sheet:\n",
    "                dcm_path = os.path.join(folder, patient_folder_name, '2')\n",
    "                dcm_list = sorted(filter(lambda x: os.path.isfile(os.path.join(dcm_path,x)) and find_dcm(x), os.listdir(dcm_path)))\n",
    "                dcm = pydicom.dcmread(os.path.join(dcm_path, dcm_list[0]), force=True)\n",
    "            else:\n",
    "                dcm_path = os.path.join(folder, patient_folder_name, '1')\n",
    "                dcm_list = sorted(filter(lambda x: os.path.isfile(os.path.join(dcm_path,x)) and find_dcm(x), os.listdir(dcm_path)))\n",
    "                dcm = pydicom.dcmread(os.path.join(dcm_path, dcm_list[0]), force=True)\n",
    "            \n",
    "            patient_info.extend([dcm.PatientID, str(dcm.PatientName), dcm.PatientSex, int(''.join(filter(func,dcm.PatientAge)))])\n",
    "            \n",
    "            patient_info.extend([dcm.AccessionNumber, dcm.StudyDate, dcm.Manufacturer])\n",
    "            \n",
    "            try:\n",
    "                temp = dcm.InstitutionName\n",
    "                patient_info.append(temp)\n",
    "            except AttributeError:\n",
    "                patient_info.append(None)            \n",
    "            try:\n",
    "                temp = dcm.SpacingBetweenSlices\n",
    "                patient_info.append(temp)\n",
    "            except AttributeError:\n",
    "                patient_info.append(None)\n",
    "            try:\n",
    "                temp = dcm.SliceThickness\n",
    "                patient_info.append(temp)\n",
    "            except AttributeError:\n",
    "                patient_info.append(None)\n",
    "            try:\n",
    "                temp = list(dcm.PixelSpacing)\n",
    "                patient_info.append('/'.join([str(t) for t in temp]))\n",
    "            except AttributeError:\n",
    "                patient_info.append(None)\n",
    "            sht.append(patient_info)\n",
    "            shttotal.append(patient_info)\n",
    "    except Exception as e:\n",
    "        print(patient, folder, e)\n",
    "        raise e"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8e6b6d4a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/center/train:  98%|▉| 597/609 [01:25<00:01,  6.38it/home/zsxm/miniconda3/envs/pytorch/lib/python3.8/site-packages/pydicom/charset.py:714: UserWarning: Value 'GB18030' cannot be used as code extension, ignoring it\n",
      "  py_encodings = _handle_illegal_standalone_encodings(\n",
      "/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/center/train: 100%|█| 609/609 [01:27<00:00,  6.94it\n",
      "/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/center/val: 100%|█| 168/168 [00:24<00:00,  6.86it/s\n",
      "/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/center/test: 100%|█| 84/84 [00:12<00:00,  6.70it/s]\n",
      "/nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500/center/train: 100%|█| 556/556 [01:36<00:00,  5.76i\n",
      "/nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500/center/val: 100%|█| 158/158 [00:28<00:00,  5.53it/\n",
      "/nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500/center/test: 100%|█| 79/79 [00:14<00:00,  5.30it/s\n",
      "/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/center/train: 100%|█| 609/609 [00:48<00:00, 12.46it\n",
      "/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/center/val: 100%|█| 168/168 [00:14<00:00, 11.45it/s\n",
      "/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/center/test: 100%|█| 84/84 [00:07<00:00, 11.35it/s]\n",
      "/nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500/center/train: 100%|█| 556/556 [01:28<00:00,  6.30i\n",
      "/nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500/center/val: 100%|█| 158/158 [00:25<00:00,  6.11it/\n",
      "/nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500/center/test: 100%|█| 79/79 [00:13<00:00,  5.86it/s\n"
     ]
    }
   ],
   "source": [
    "yolo_ct_train = '/nfs3-p1/zsxm/dataset/aorta_yolo_ct/images/train'\n",
    "yolo_ct_val = '/nfs3-p1/zsxm/dataset/aorta_yolo_ct/images/val'\n",
    "yolo_ct_test = '/nfs3-p1/zsxm/dataset/aorta_yolo_ct/images/test'\n",
    "\n",
    "yolo_cta_train = '/nfs3-p1/zsxm/dataset/aorta_yolo_cta/images/train'\n",
    "yolo_cta_val = '/nfs3-p1/zsxm/dataset/aorta_yolo_cta/images/val'\n",
    "yolo_cta_test = '/nfs3-p1/zsxm/dataset/aorta_yolo_cta/images/test'\n",
    "\n",
    "ct_train = '/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/center/train'\n",
    "ct_val = '/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/center/val'\n",
    "ct_test = '/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/center/test'\n",
    "\n",
    "cta_train = '/nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500/center/train'\n",
    "cta_val = '/nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500/center/val'\n",
    "cta_test = '/nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500/center/test'\n",
    "\n",
    "#统计检测数据集\n",
    "# statistics_dataset(yolo_ct_train, 'yolo_ct', wb)\n",
    "# statistics_dataset(yolo_ct_val, 'yolo_ct', wb)\n",
    "# statistics_dataset(yolo_ct_test, 'yolo_ct', wb)\n",
    "\n",
    "# statistics_dataset(yolo_cta_train, 'yolo_cta', wb)\n",
    "# statistics_dataset(yolo_cta_val, 'yolo_cta', wb)\n",
    "# statistics_dataset(yolo_cta_test, 'yolo_cta', wb)\n",
    "\n",
    "# statistics_dataset(yolo_ct_train, 'yolo_ct_train', wb)\n",
    "# statistics_dataset(yolo_ct_val, 'yolo_ct_val', wb)\n",
    "# statistics_dataset(yolo_ct_test, 'yolo_ct_test', wb)\n",
    "\n",
    "# statistics_dataset(yolo_cta_train, 'yolo_cta_train', wb)\n",
    "# statistics_dataset(yolo_cta_val, 'yolo_cta_val', wb)\n",
    "# statistics_dataset(yolo_cta_test, 'yolo_cta_test', wb)\n",
    "\n",
    "#统计分类数据集\n",
    "statistics_dataset(ct_train, 'classify_ct', wb)\n",
    "statistics_dataset(ct_val, 'classify_ct', wb)\n",
    "statistics_dataset(ct_test, 'classify_ct', wb)\n",
    "\n",
    "statistics_dataset(cta_train, 'classify_cta', wb)\n",
    "statistics_dataset(cta_val, 'classify_cta', wb)\n",
    "statistics_dataset(cta_test, 'classify_cta', wb)\n",
    "\n",
    "statistics_dataset(ct_train, 'ct_train', wb)\n",
    "statistics_dataset(ct_val, 'ct_val', wb)\n",
    "statistics_dataset(ct_test, 'ct_test', wb)\n",
    "\n",
    "statistics_dataset(cta_train, 'cta_train', wb)\n",
    "statistics_dataset(cta_val, 'cta_val', wb)\n",
    "statistics_dataset(cta_test, 'cta_test', wb)\n",
    "\n",
    "try:\n",
    "    del wb['Sheet']\n",
    "except Exception as e:\n",
    "    print(e)\n",
    "wb.save(xlsx_path)\n",
    "wb.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "8cc51c05",
   "metadata": {},
   "outputs": [],
   "source": [
    "# /nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/center/train:  97%|▉| 672/690 [00:48<00:03,  5.79it/disk1/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/pydicom/charset.py:715: UserWarning: Value 'ISO_IR 192' for Specific Character Set does not allow code extensions, ignoring: GB18030\n",
    "#   encodings, py_encodings\n",
    "# /nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/center/train: 100%|█| 690/690 [00:51<00:00, 13.48it\n",
    "# /nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/center/val: 100%|█| 188/188 [00:25<00:00,  7.46it/s\n",
    "# /nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/center/test: 100%|█| 95/95 [00:13<00:00,  7.12it/s]\n",
    "# /nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500/center/train: 100%|█| 637/637 [01:27<00:00,  7.28i\n",
    "# /nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500/center/val: 100%|█| 178/178 [00:35<00:00,  4.95it/\n",
    "# /nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500/center/test: 100%|█| 90/90 [00:18<00:00,  4.78it/s"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1669188",
   "metadata": {},
   "source": [
    "# 5.使用yolo生成用来病例测试的数据"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "387c112b",
   "metadata": {},
   "source": [
    "## 5.1 将plain CT的yolo检测label移回原数据文件夹"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16699f3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import shutil\n",
    "from tqdm import tqdm\n",
    "\n",
    "ct_train = '/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/train'\n",
    "ct_val = '/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/val'\n",
    "ct_test = '/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/test'\n",
    "\n",
    "def list_patient(root):\n",
    "    ps = set()\n",
    "    for cate in ['0', '1', '2']:\n",
    "        for img in os.listdir(os.path.join(root, cate)):\n",
    "            ps.add(img.split('_')[0])\n",
    "    return ps\n",
    "    \n",
    "train = list_patient(ct_train)\n",
    "val = list_patient(ct_val)\n",
    "test = list_patient(ct_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdc0eae4",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(len(train), len(val), len(test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4d91588",
   "metadata": {},
   "outputs": [],
   "source": [
    "folders = {'/nfs3-p2/zsxm/dataset/2021-07-23-10-negative':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-07-23-4':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-07-30':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-09-08':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-09-13':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-09-17-negative':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-09-19':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-09-28':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-09-29-negative':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-10-19-aa':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-10-19-imh':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-10-19-pau':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-11-20':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-11-20-imh':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-11-20-pau':list(),    \n",
    "          }\n",
    "for k, v in folders.items():\n",
    "    v.extend(sorted(list(filter(lambda x: os.path.isdir(os.path.join(k, x)), os.listdir(k)))))\n",
    "    \n",
    "origins = {'/nfs3-p1/zsxm/dataset/aorta_ct_img_label/negative':list(),\n",
    "           '/nfs3-p1/zsxm/dataset/aorta_ct_img_label/positive':list(),\n",
    "           '/nfs3-p1/zsxm/dataset/aorta_ct_img_label/positive2':list(),\n",
    "           '/nfs3-p1/zsxm/dataset/aorta_ct_img_label/imh':list(),\n",
    "           '/nfs3-p1/zsxm/dataset/aorta_ct_img_label/imh2':list(),\n",
    "          }\n",
    "for k, v in origins.items():\n",
    "    v.extend(sorted(list(filter(lambda x: os.path.isdir(os.path.join(k, x)), os.listdir(k)))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c919aef",
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_exist(dataset):\n",
    "    p_count = 0\n",
    "    for patient in dataset:\n",
    "        f_count, o_count = 0, 0\n",
    "        for folder, content in folders.items():\n",
    "            if patient in content:\n",
    "                f_count += 1\n",
    "        for folder, content in origins.items():\n",
    "            if patient in content:\n",
    "                o_count += 1\n",
    "        if f_count == 1 and o_count == 1:\n",
    "            p_count += 1\n",
    "    print(len(dataset), p_count)\n",
    "    \n",
    "check_exist(train)\n",
    "check_exist(val)\n",
    "check_exist(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0954625",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def copy_back(dataset):\n",
    "    print('*******************************************************************************************************************')\n",
    "    for patient in tqdm(dataset):\n",
    "        origin_path = []\n",
    "        dst_path = []\n",
    "        for folder, content in origins.items():\n",
    "            if patient in content:\n",
    "                origin_path.append(folder)\n",
    "        for folder, content in folders.items():\n",
    "            if patient in content:\n",
    "                dst_path.append(folder)\n",
    "        assert len(origin_path) == len(dst_path) == 1, f'{patient}:{origin_path}:{dst_path}'\n",
    "        origin_path = origin_path[0]\n",
    "        dst_path = dst_path[0]\n",
    "        \n",
    "        pati_ori_labels = os.path.join(origin_path, patient, 'labels')\n",
    "        pati_dst_labels = os.path.join(dst_path, patient, '1', 'labels')\n",
    "        \n",
    "        assert os.path.exists(pati_ori_labels), f'{patient} pati_ori_labels not exists'\n",
    "        assert not os.path.exists(pati_dst_labels), f'{patient} pati_dst_labels exists'\n",
    "        \n",
    "        shutil.copytree(pati_ori_labels, pati_dst_labels)\n",
    "    print('*******************************************************************************************************************')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5cabc57",
   "metadata": {},
   "source": [
    "## 5.2 对于所有CT和CTA，把检测结果全部切出来并存储"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a63a2edb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_img_exists(dataset):\n",
    "    for patient in tqdm(dataset):\n",
    "        folder_path = []\n",
    "        for folder, content in folders.items():\n",
    "            if patient in content:\n",
    "                folder_path.append(folder)\n",
    "        assert len(folder_path) == 1, f'{patient}:{folder_path}'\n",
    "        folder_path = folder_path[0]\n",
    "        patient_path = os.path.join(folder_path, patient)\n",
    "        \n",
    "        assert os.path.exists(os.path.join(patient_path, '1', 'images_-100_500')), f\"{patient} 1\"\n",
    "        if os.path.exists(os.path.join(patient_path, '2')):\n",
    "            assert os.path.exists(os.path.join(patient_path, '2', 'images_-100_500')), f\"{patient} 2\"\n",
    "        else:\n",
    "            print('2 not exists:', patient)\n",
    "\n",
    "check_img_exists(train)\n",
    "check_img_exists(val)\n",
    "check_img_exists(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a2b37ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from PIL import Image\n",
    "\n",
    "def total_crop(dataset):\n",
    "    dataset = sorted(list(dataset))\n",
    "    for kkk, patient in enumerate(tqdm(dataset)):\n",
    "        folder_path = []\n",
    "        for folder, content in folders.items():\n",
    "            if patient in content:\n",
    "                folder_path.append(folder)\n",
    "        assert len(folder_path) == 1, f'{patient}:{folder_path}'\n",
    "        folder_path = folder_path[0]\n",
    "        patient_path = os.path.join(folder_path, patient)\n",
    "        \n",
    "        for cate in ['1', '2']:\n",
    "            if cate == '2' and not os.path.exists(os.path.join(patient_path, cate)):\n",
    "                print('ignonre 2:', patient)\n",
    "                continue\n",
    "            \n",
    "            conti_thresh = 3 if cate == '1' else 6\n",
    "            cate_path = os.path.join(patient_path, cate)\n",
    "            image_path = os.path.join(cate_path, 'images_-100_500')\n",
    "            label_path = os.path.join(cate_path, 'labels')\n",
    "            total_crop_path = os.path.join(cate_path, 'total_crops_-100_500')\n",
    "            if os.path.exists(total_crop_path):\n",
    "                shutil.rmtree(total_crop_path)\n",
    "            os.mkdir(total_crop_path)\n",
    "            \n",
    "            branch_start, branch_end = [], []\n",
    "            pre_aorta_num = 1\n",
    "            same_aorta_count = 0\n",
    "            has_change_flag = False\n",
    "            for i, label in enumerate(sorted(os.listdir(label_path))):\n",
    "                with open(os.path.join(label_path, label), 'r') as f:\n",
    "                    lines = f.readlines()\n",
    "                num = len(lines)\n",
    "                assert 0<num<3, f'{os.path.join(label_path, label)}:{num}'\n",
    "                if num == pre_aorta_num:\n",
    "                    same_aorta_count += 1\n",
    "                    if has_change_flag and same_aorta_count == conti_thresh:\n",
    "                        if num == 2:\n",
    "                            branch_start.append(i-conti_thresh+1)\n",
    "                        else:\n",
    "                            branch_end.append(i-conti_thresh+1)\n",
    "                else:\n",
    "                    has_change_flag = True\n",
    "                    pre_aorta_num = num\n",
    "                    same_aorta_count = 1\n",
    "            if cate == '1':\n",
    "                #assert len(branch_start) > 0 and len(branch_end) > 0, f'{patient}*{cate}:{branch_start}:{branch_end}'\n",
    "                if not (len(branch_start) > 0 and len(branch_end) > 0):\n",
    "                    print(f'manual set branch_start 1:{patient}*{cate}:{branch_start}:{branch_end}')\n",
    "                    branch_start = [0]\n",
    "                    branch_end = [len(os.listdir(label_path))-1]\n",
    "            else:\n",
    "                if not (len(branch_start) > 0 and len(branch_end) > 0):\n",
    "                    print(f'ignore {patient}*{cate}:{branch_start}:{branch_end}')\n",
    "                    continue\n",
    "            assert min(branch_start) >= 0, f'{patient}: {branch_start}'\n",
    "            assert max(branch_end) < len(os.listdir(label_path)), f'{patient}: {branch_end}---{len(os.listdir(label_path))}'\n",
    "            \n",
    "            branch_start = min(branch_start)\n",
    "            branch_end = min(filter(lambda x: x>branch_start, branch_end))\n",
    "            \n",
    "            j_count, s_count = 0, 0\n",
    "            for i, label in enumerate(sorted(os.listdir(label_path))):\n",
    "                if i < branch_start:\n",
    "                    continue\n",
    "                \n",
    "                with open(os.path.join(label_path, label), 'r') as f:\n",
    "                    lines = f.readlines()\n",
    "                    \n",
    "                jx, jy, jw, jh = -1, -1, -1, -1\n",
    "                sx, sy, sw, sh = -1, -1, -1, -1\n",
    "                pre_jx, pre_jy = 2, 2\n",
    "                if len(lines) == 1:\n",
    "                    corr = list(map(lambda x: float(x), lines[0].split()))\n",
    "                    jx, jy, jw, jh = corr[1], corr[2], corr[3], corr[4]\n",
    "                    if not (0.25 < jx < 0.75 and 0.15 < jy < 0.85):\n",
    "                        jx, jy, jw, jh = -1, -1, -1, -1\n",
    "                    else:\n",
    "                        pre_jx, pre_jy = jx, jy\n",
    "                else:\n",
    "                    corr = [list(map(lambda x: float(x), lines[0].split())), list(map(lambda x: float(x), lines[1].split()))]\n",
    "                    if i in range(branch_start, branch_end):\n",
    "                        sx, sy, sw, sh = (corr[0][1], corr[0][2], corr[0][3], corr[0][4]) if corr[0][2] <= corr[1][2] \\\n",
    "                            else (corr[1][1], corr[1][2], corr[1][3], corr[1][4])\n",
    "                        jx, jy, jw, jh = (corr[0][1], corr[0][2], corr[0][3], corr[0][4]) if corr[0][2] > corr[1][2] \\\n",
    "                            else (corr[1][1], corr[1][2], corr[1][3], corr[1][4])\n",
    "                        if not (0.25 < sx < 0.75 and 0.15 < sy < 0.85):\n",
    "                            sx, sy, sw, sh = -1, -1, -1, -1\n",
    "                        if not (0.25 < jx < 0.75 and 0.15 < jy < 0.85):\n",
    "                            jx, jy, jw, jh = -1, -1, -1, -1\n",
    "                    else:\n",
    "                        min_dis, min_c = float('inf'), -1\n",
    "                        for c in range(2):\n",
    "                            dis = (corr[c][1]-pre_jx)**2 + (corr[c][2]-pre_jy)**2\n",
    "                            if dis < min_dis:\n",
    "                                min_dis = dis\n",
    "                                min_c = c\n",
    "                        jx, jy, jw, jh = corr[min_c][1], corr[min_c][2], corr[min_c][3], corr[min_c][4]\n",
    "                        if not (0.25 < jx < 0.75 and 0.15 < jy < 0.85):\n",
    "                            jx, jy, jw, jh = -1, -1, -1, -1\n",
    "                        else:\n",
    "                            pre_jx, pre_jy = jx, jy\n",
    "                    \n",
    "                if jx != -1 or sx != -1:\n",
    "                    imgname = os.path.splitext(label)[0]+'.png'\n",
    "                    img = np.array(Image.open(os.path.join(image_path, imgname)))\n",
    "                    height, width = img.shape[0], img.shape[1]\n",
    "                else:\n",
    "                    continue\n",
    "                \n",
    "                if jx != -1:\n",
    "                    jw, jh = int(width*jw), int(height*jh)\n",
    "                    jw = jh = max(jw, jh)\n",
    "                    jx1, jy1, jx2, jy2 = int(width*jx-jw/2), int(height*jy-jh/2), int(width*jx+jw/2+1), int(height*jy+jh/2+1)\n",
    "                    j_crop = Image.fromarray(img[jy1:jy2, jx1:jx2])\n",
    "                    j_crop.save(os.path.join(total_crop_path, f'j_{j_count:04d}_{imgname}'))\n",
    "                    j_count += 1\n",
    "                if sx != -1:\n",
    "                    sw, sh = int(width*sw), int(height*sh)\n",
    "                    sw = sh = max(sw, sh)\n",
    "                    sx1, sy1, sx2, sy2 = int(width*sx-sw/2), int(height*sy-sh/2), int(width*sx+sw/2+1), int(height*sy+sh/2+1)\n",
    "                    s_crop = Image.fromarray(img[sy1:sy2, sx1:sx2])\n",
    "                    s_crop.save(os.path.join(total_crop_path, f's_{s_count:04d}_{imgname}'))\n",
    "                    s_count += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68bde809",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Build the cropped-aorta image set for the validation-split patients.\n",
    "total_crop(val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "024cfbe4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the cropped-aorta image set for the test-split patients.\n",
    "total_crop(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56787d31",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the cropped-aorta image set for the train-split patients.\n",
    "total_crop(train)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "85e314bc",
   "metadata": {},
   "source": [
    "## 5.3 统一移动"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dca8ab15",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import shutil\n",
    "from tqdm import tqdm\n",
    "\n",
    "# NOTE(review): the variable names say 'ct' but the paths point at the CTA\n",
    "# dataset (aorta_classify_cta_-100_500) -- confirm which modality is intended.\n",
    "ct_train = '/nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500/train'\n",
    "ct_val = '/nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500/val'\n",
    "ct_test = '/nfs3-p1/zsxm/dataset/aorta_classify_cta_-100_500/test'\n",
    "\n",
    "# NOTE(review): shadows the list_patient defined in section 1 (different signature).\n",
    "def list_patient(root):\n",
    "    \"\"\"Return a list of three patient-ID sets, one per class folder '0'/'1'/'2'.\n",
    "\n",
    "    The patient ID is taken as the part of each file name before the first '_'.\n",
    "    \"\"\"\n",
    "    ps = [set(), set(), set()]\n",
    "    for cate in range(3):\n",
    "        for img in os.listdir(os.path.join(root, str(cate))):\n",
    "            ps[cate].add(img.split('_')[0])\n",
    "    return ps\n",
    "\n",
    "train = list_patient(ct_train)\n",
    "val = list_patient(ct_val)\n",
    "test = list_patient(ct_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5909570d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_dataset(dataset):\n",
    "    \"\"\"Print the patient count of each class set, space-separated, on one line.\"\"\"\n",
    "    line = ''.join(f'{len(cate_set)} ' for cate_set in dataset)\n",
    "    print(line)\n",
    "\n",
    "# Recorded counts for the CT splits (rows: train/val/test; cols: neg, ad, imh)\n",
    "'''\n",
    "333 276 81 \n",
    "94 74 20 \n",
    "46 38 11 \n",
    "'''\n",
    "# Recorded counts for the CTA splits (rows: train/val/test; cols: neg, ad, imh)\n",
    "'''\n",
    "280 276 81 \n",
    "84 74 20 \n",
    "41 38 11 \n",
    "'''\n",
    "print_dataset(train)\n",
    "print_dataset(val)\n",
    "print_dataset(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9333e658",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map each acquisition-batch folder to the sorted list of patient\n",
    "# sub-directories found inside it (filled by the loop below).\n",
    "folders = {'/nfs3-p2/zsxm/dataset/2021-07-23-10-negative':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-07-23-4':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-07-30':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-09-08':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-09-13':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-09-17-negative':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-09-19':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-09-28':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-09-29-negative':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-10-19-aa':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-10-19-imh':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-10-19-pau':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-11-20':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-11-20-imh':list(),\n",
    "           '/nfs3-p2/zsxm/dataset/2021-11-20-pau':list(),    \n",
    "          }\n",
    "# Record every patient directory actually present in each batch folder.\n",
    "for k, v in folders.items():\n",
    "    v.extend(sorted(list(filter(lambda x: os.path.isdir(os.path.join(k, x)), os.listdir(k)))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c8e753b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def copy_together(dataset, root, ct):\n",
    "    \"\"\"Copy every patient's cropped images into root/<class index>/<patient>.\n",
    "\n",
    "    dataset: list of per-class patient-ID sets (as built by list_patient).\n",
    "    root: destination directory; one sub-folder per class index is created.\n",
    "    ct: sub-folder name under each patient directory ('1' or '2';\n",
    "        presumably CT vs CTA -- confirm against the callers below).\n",
    "    \"\"\"\n",
    "    for i, cdata in enumerate(dataset):\n",
    "        cate_path = os.path.join(root, str(i))\n",
    "        os.makedirs(cate_path)\n",
    "        for patient in tqdm(cdata):\n",
    "            # A patient ID must appear in exactly one acquisition-batch folder.\n",
    "            folder_path = []\n",
    "            for folder, content in folders.items():\n",
    "                if patient in content:\n",
    "                    folder_path.append(folder)\n",
    "            assert len(folder_path) == 1, f'{patient}:{folder_path}'\n",
    "            folder_path = os.path.join(folder_path[0], patient, ct, 'total_crops_-100_500')\n",
    "\n",
    "            # Crops must exist when ct == '1'; missing data is only tolerated for ct == '2'.\n",
    "            assert (ct == '1' and os.path.exists(folder_path)) or ct == '2', f'{i}:{patient}'\n",
    "                \n",
    "            if not os.path.exists(folder_path) or len(os.listdir(folder_path)) == 0:\n",
    "                print('ignore:', ct, patient, )\n",
    "                continue\n",
    "\n",
    "            shutil.copytree(folder_path, os.path.join(cate_path, patient))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1730130c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The CTA dataset has fewer cases than CT because the later part of the 9-17\n",
    "# batch was never annotated (those cases were skipped at generation time), not\n",
    "# because augmented data is missing. So CTA statistics must be gathered from the\n",
    "# CTA dataset directly rather than by simply changing ct='1' to '2'.\n",
    "# NOTE(review): both disabled calls write into a 'train' folder even though the\n",
    "# val/test patient sets are passed in -- confirm before re-enabling.\n",
    "# copy_together(val, '/nfs3-p2/zsxm/dataset/scan_ct/train', '1')\n",
    "# copy_together(test, '/nfs3-p2/zsxm/dataset/scan_cta/train', '2')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "98c74038",
   "metadata": {},
   "source": [
    "## 5.4 病例判断"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f95585a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import yaml\n",
    "from types import SimpleNamespace\n",
    "from collections import OrderedDict\n",
    "\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from PIL import Image\n",
    "from torch.utils.data import DataLoader, WeightedRandomSampler, Dataset\n",
    "from torch import optim\n",
    "import torchvision.transforms as T\n",
    "from sklearn import metrics\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from data.datasets import AortaDataset3DCenter, get_weight_list\n",
    "from utils.ranger import Ranger\n",
    "from utils.lr_scheduler import CosineAnnealingWithWarmUpLR\n",
    "from model.resnet3d import resnet3d\n",
    "from model.SupCon import resnet\n",
    "import data.transforms as MT\n",
    "\n",
    "from scipy import integrate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a60456a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the trained 3D ResNet (depth 34, 2 input channels, 2 output classes)\n",
    "# from its checkpoint and put it in eval mode for inference.\n",
    "device = torch.device('cuda:2')\n",
    "load_model = '.details151/CE3D/12-05_183652_3D,有Sobel,只分类阴性和夹层/Net_best.pth'\n",
    "net = resnet3d(34, n_channels=2, n_classes=2, conv1_t_size=3)\n",
    "net.load_state_dict(torch.load(load_model, map_location=device))\n",
    "net.to(device)\n",
    "net.eval()\n",
    "print()  # suppress the model repr that net.eval() would otherwise display"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3732404d",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Alternative approach (disabled): concatenate the entire patient volume and feed it to the network in one pass\n",
    "# class PatientDataset(Dataset):\n",
    "#     def __init__(self, root, transform):\n",
    "#         self.datas = []\n",
    "#         self.transform = transform\n",
    "#         for i in ['0', '1']:\n",
    "#             cpath = os.path.join(root, i)\n",
    "#             i = int(i)\n",
    "#             for patient in sorted(os.listdir(cpath)):\n",
    "#                 ppath = os.path.join(cpath, patient)\n",
    "#                 imgs = os.listdir(ppath)\n",
    "#                 pd = {}\n",
    "#                 pd['s'] = list(map(lambda x: os.path.join(ppath, x), sorted(filter(lambda x: x.startswith('s_'), imgs))))\n",
    "#                 pd['j'] = list(map(lambda x: os.path.join(ppath, x), sorted(filter(lambda x: x.startswith('j_'), imgs))))\n",
    "# #                 pd['s'] = [pd['s'][0]]*3 + pd['s'] + [pd['s'][-1]]*3\n",
    "# #                 pd['j'] = [pd['j'][0]]*3 + pd['j'] + [pd['j'][-1]]*3\n",
    "#                 self.datas.append([patient, pd, i])\n",
    "        \n",
    "#     def __len__(self):\n",
    "#         return len(self.datas)\n",
    "        \n",
    "#     def __getitem__(self, index):\n",
    "#         patient, pd, label = self.datas[index]\n",
    "#         simgs = list(map(lambda x: Image.open(x), pd['s']))\n",
    "#         jimgs = list(map(lambda x: Image.open(x), pd['j']))\n",
    "#         simgs = self.transform(simgs)\n",
    "#         jimgs = self.transform(jimgs)\n",
    "#         simgs = torch.stack(simgs, dim=1)\n",
    "#         jimgs = torch.stack(jimgs, dim=1)\n",
    "#         return simgs, jimgs, label, patient\n",
    "    \n",
    "# vt_list = [\n",
    "#     MT.Resize3D(81),\n",
    "#     MT.CenterCrop3D(81),\n",
    "#     MT.ToTensor3D(),\n",
    "#     MT.SobelChannel(3, flag_3d=True)\n",
    "# ]\n",
    "# transform = T.Compose(vt_list)\n",
    "# dataset = PatientDataset('/nfs3-p2/zsxm/dataset/scan_ct/test', transform)\n",
    "\n",
    "# true_list = []\n",
    "# pred_list = []\n",
    "# pred_ori_list = []\n",
    "# with torch.no_grad():\n",
    "#     with tqdm(total=len(dataset), desc=f'test round', unit='img', leave=False) as pbar:\n",
    "#         for simgs, jimgs, label, patient in dataset:\n",
    "#             simgs, jimgs, = simgs.to(device), jimgs.to(device)\n",
    "#             true_list.append(label)\n",
    "#             preds = net(torch.cat([simgs, jimgs], dim=1).unsqueeze(0))\n",
    "#             pred_idx = torch.softmax(preds, dim=1)\n",
    "#             pred_ori_list += pred_idx.tolist()\n",
    "#             pred_idx = pred_idx.argmax(dim=1)\n",
    "#             pred_idx = pred_idx.tolist()\n",
    "#             pred_list.extend(pred_idx)\n",
    "#             pbar.update(1)\n",
    "\n",
    "# AP = []\n",
    "# for c in range(2):\n",
    "#     c_true_list = [int(item==c) for item in true_list]\n",
    "#     c_pred_ori_list = [item[c] for item in pred_ori_list]\n",
    "#     AP.append(metrics.average_precision_score(c_true_list, c_pred_ori_list))\n",
    "\n",
    "# print(f'report:\\n'+metrics.classification_report(true_list, pred_list, digits=4))\n",
    "\n",
    "# print(float(np.mean(AP)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7b943319",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:30<00:00,  2.73it/s]\n",
      "2: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:29<00:00,  2.81it/s]\n",
      "3: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:30<00:00,  2.80it/s]\n",
      "4: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:30<00:00,  2.74it/s]\n",
      "5: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:30<00:00,  2.73it/s]\n",
      "6: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:30<00:00,  2.75it/s]\n",
      "7: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:30<00:00,  2.76it/s]\n",
      "8: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:29<00:00,  2.81it/s]\n",
      "9: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:30<00:00,  2.79it/s]\n",
      "10: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:30<00:00,  2.77it/s]\n",
      "11: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:30<00:00,  2.76it/s]\n",
      "12: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:30<00:00,  2.77it/s]\n",
      "13: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:30<00:00,  2.75it/s]\n",
      "14: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:30<00:00,  2.74it/s]\n",
      "15: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:30<00:00,  2.71it/s]\n",
      "16: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:30<00:00,  2.73it/s]\n",
      "17: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:30<00:00,  2.73it/s]\n",
      "18: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:30<00:00,  2.72it/s]\n",
      "19: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:30<00:00,  2.73it/s]\n",
      "20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:30<00:00,  2.72it/s]\n",
      "21: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:30<00:00,  2.74it/s]\n",
      "22: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:30<00:00,  2.73it/s]\n",
      "23: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:30<00:00,  2.72it/s]\n",
      "24: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:30<00:00,  2.72it/s]\n",
      "25: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:30<00:00,  2.71it/s]\n",
      "26: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:30<00:00,  2.73it/s]\n",
      "27: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:30<00:00,  2.72it/s]\n",
      "28: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:30<00:00,  2.73it/s]\n",
      "29: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:30<00:00,  2.72it/s]\n",
      "30: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:30<00:00,  2.72it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10 0.9404761904761905\n",
      "report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.9767    0.9130    0.9438        46\n",
      "           1     0.9024    0.9737    0.9367        38\n",
      "\n",
      "    accuracy                         0.9405        84\n",
      "   macro avg     0.9396    0.9434    0.9403        84\n",
      "weighted avg     0.9431    0.9405    0.9406        84\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Slice-window voting: each 7-slice window is classified independently and a\n",
    "# patient is called positive when some run of identical positive window\n",
    "# predictions is at least N long; N is swept from 1 to N_MAX.\n",
    "N_MAX = 30\n",
    "\n",
    "class PatientDataset(Dataset):\n",
    "    \"\"\"One item per patient under root/0 (negative) and root/1 (positive).\n",
    "\n",
    "    Gathers each patient's 'j_*' and 's_*' crop sequences (presumably the\n",
    "    descending/ascending aorta crops -- see the sentinel comment below) and\n",
    "    pads both with 3 copies of the first and last slice so that every\n",
    "    original slice can sit at the centre of a 7-slice window.\n",
    "    \"\"\"\n",
    "    def __init__(self, root, transform):\n",
    "        self.datas = []\n",
    "        self.transform = transform\n",
    "        for i in ['0', '1']:\n",
    "            cpath = os.path.join(root, i)\n",
    "            i = int(i)  # the folder name doubles as the integer class label\n",
    "            for patient in sorted(os.listdir(cpath)):\n",
    "                ppath = os.path.join(cpath, patient)\n",
    "                imgs = os.listdir(ppath)\n",
    "                pd = {}\n",
    "                pd['j'] = list(map(lambda x: os.path.join(ppath, x), sorted(filter(lambda x: x.startswith('j_'), imgs))))\n",
    "                pd['s'] = list(map(lambda x: os.path.join(ppath, x), sorted(filter(lambda x: x.startswith('s_'), imgs))))\n",
    "                # Pad with 3 copies at each end so 7-slice windows cover all slices.\n",
    "                pd['j'] = [pd['j'][0]]*3 + pd['j'] + [pd['j'][-1]]*3\n",
    "                pd['s'] = [pd['s'][0]]*3 + pd['s'] + [pd['s'][-1]]*3\n",
    "                self.datas.append([patient, pd, i])\n",
    "        \n",
    "    def __len__(self):\n",
    "        return len(self.datas)\n",
    "        \n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Return (jimgs, simgs, label, patient) with the transform applied.\"\"\"\n",
    "        patient, pd, label = self.datas[index]\n",
    "        jimgs = list(map(lambda x: Image.open(x), pd['j']))\n",
    "        simgs = list(map(lambda x: Image.open(x), pd['s']))\n",
    "        jimgs = self.transform(jimgs)\n",
    "        simgs = self.transform(simgs)\n",
    "        return jimgs, simgs, label, patient\n",
    "    \n",
    "# Eval-time transform: resize + centre-crop to 81, to tensor, add Sobel channel.\n",
    "vt_list = [\n",
    "    MT.Resize3D(81),\n",
    "    MT.CenterCrop3D(81),\n",
    "    MT.ToTensor3D(),\n",
    "    MT.SobelChannel(3, flag_3d=True)\n",
    "]\n",
    "transform = T.Compose(vt_list)\n",
    "dataset = PatientDataset('/nfs3-p2/zsxm/dataset/scan_ct/test', transform)\n",
    "\n",
    "# NOTE(review): N is selected on this same test set, so the reported best\n",
    "# accuracy is optimistically biased -- tune N on the validation split instead.\n",
    "max_acc = -1\n",
    "max_n = -1\n",
    "max_true_list = []\n",
    "max_pred_list = []\n",
    "for N in range(1, N_MAX+1):\n",
    "    true_list = []\n",
    "    pred_list = []\n",
    "    with torch.no_grad():\n",
    "        for jimgs, simgs, label, patient in tqdm(dataset, desc=f'{N}'):\n",
    "            true_list.append(label)\n",
    "\n",
    "            # Build sliding 7-slice windows (stacked along depth dim=1).\n",
    "            jimgs = [torch.stack(jimgs[i:i+7], dim=1) for i in range(len(jimgs)-6)]\n",
    "            simgs = [torch.stack(simgs[i:i+7], dim=1) for i in range(len(simgs)-6)]\n",
    "            imgs = torch.stack(jimgs+simgs, dim=0).to(device)\n",
    "\n",
    "            preds = net(imgs)\n",
    "            preds_idx = preds.argmax(dim=1).tolist()\n",
    "            preds_idx.insert(len(jimgs), -2) # sentinel: break the prediction sequence between the two aorta segments\n",
    "\n",
    "            # Patient label = class of the longest run (>= N) of identical positive\n",
    "            # predictions. (2 cannot occur with a 2-class net; presumably kept for\n",
    "            # a 3-class variant.)\n",
    "            pred_label = 0\n",
    "            pre = -1\n",
    "            max_len = -1\n",
    "            start_idx = -1\n",
    "            for i in range(len(preds_idx)):\n",
    "                cur = preds_idx[i]\n",
    "                if cur != pre:\n",
    "                    if pre in [1, 2]:\n",
    "                        ln = i - start_idx\n",
    "                        if ln >= N and ln > max_len:\n",
    "                            max_len = ln\n",
    "                            pred_label = pre\n",
    "                    start_idx = i\n",
    "                pre = cur\n",
    "            # Fix: close the final run, which the loop above never flushes --\n",
    "            # without this a positive run ending the sequence is ignored.\n",
    "            if pre in [1, 2]:\n",
    "                ln = len(preds_idx) - start_idx\n",
    "                if ln >= N and ln > max_len:\n",
    "                    max_len = ln\n",
    "                    pred_label = pre\n",
    "\n",
    "            pred_list.append(pred_label)\n",
    "            #print(patient, label, preds_idx, pred_label, max_len, '\\n')\n",
    "    acc = metrics.accuracy_score(true_list, pred_list)\n",
    "    if acc > max_acc:\n",
    "        max_acc = acc\n",
    "        max_n = N\n",
    "        max_true_list = true_list\n",
    "        max_pred_list = pred_list\n",
    "\n",
    "print(max_n, max_acc)\n",
    "print(f'report:\\n'+metrics.classification_report(max_true_list, max_pred_list, digits=4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8cb745ed",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1/5: 100%|████████| 79/79 [01:25<00:00,  1.09s/it]\n",
      "2/5: 100%|████████| 79/79 [01:11<00:00,  1.11it/s]\n",
      "3/5: 100%|████████| 79/79 [01:08<00:00,  1.16it/s]\n",
      "4/5: 100%|████████| 79/79 [01:08<00:00,  1.14it/s]\n",
      "5/5: 100%|████████| 79/79 [01:10<00:00,  1.12it/s]\n"
     ]
    }
   ],
   "source": [
    "# For each of 5 trained CTA models, sweep the run-length threshold N and the\n",
    "# window probability threshold to build a patient-level ROC; collect four\n",
    "# statistics per (model, N): AUC, accuracy, sensitivity, specificity.\n",
    "N_MAX = 20\n",
    "\n",
    "# NOTE(review): this redefines the PatientDataset of the previous cell with an\n",
    "# identical body; consider factoring it out to avoid the duplicate definition.\n",
    "class PatientDataset(Dataset):\n",
    "    \"\"\"One item per patient under root/0 (negative) and root/1 (positive).\n",
    "\n",
    "    Gathers each patient's 'j_*' and 's_*' crop sequences and pads both with\n",
    "    3 copies of the first and last slice so that 7-slice windows cover every\n",
    "    original slice.\n",
    "    \"\"\"\n",
    "    def __init__(self, root, transform):\n",
    "        self.datas = []\n",
    "        self.transform = transform\n",
    "        for i in ['0', '1']:\n",
    "            cpath = os.path.join(root, i)\n",
    "            i = int(i)  # the folder name doubles as the integer class label\n",
    "            for patient in sorted(os.listdir(cpath)):\n",
    "                ppath = os.path.join(cpath, patient)\n",
    "                imgs = os.listdir(ppath)\n",
    "                pd = {}\n",
    "                pd['j'] = list(map(lambda x: os.path.join(ppath, x), sorted(filter(lambda x: x.startswith('j_'), imgs))))\n",
    "                pd['s'] = list(map(lambda x: os.path.join(ppath, x), sorted(filter(lambda x: x.startswith('s_'), imgs))))\n",
    "                # Pad with 3 copies at each end so 7-slice windows cover all slices.\n",
    "                pd['j'] = [pd['j'][0]]*3 + pd['j'] + [pd['j'][-1]]*3\n",
    "                pd['s'] = [pd['s'][0]]*3 + pd['s'] + [pd['s'][-1]]*3\n",
    "                self.datas.append([patient, pd, i])\n",
    "        \n",
    "    def __len__(self):\n",
    "        return len(self.datas)\n",
    "        \n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Return (jimgs, simgs, label, patient) with the transform applied.\"\"\"\n",
    "        patient, pd, label = self.datas[index]\n",
    "        jimgs = list(map(lambda x: Image.open(x), pd['j']))\n",
    "        simgs = list(map(lambda x: Image.open(x), pd['s']))\n",
    "        jimgs = self.transform(jimgs)\n",
    "        simgs = self.transform(simgs)\n",
    "        return jimgs, simgs, label, patient\n",
    "    \n",
    "# Eval-time transform: resize + centre-crop to 81, to tensor, add Sobel channel.\n",
    "vt_list = [\n",
    "    MT.Resize3D(81),\n",
    "    MT.CenterCrop3D(81),\n",
    "    MT.ToTensor3D(),\n",
    "    MT.SobelChannel(3, flag_3d=True)\n",
    "]\n",
    "transform = T.Compose(vt_list)\n",
    "dataset = PatientDataset('/nfs3-p2/zsxm/dataset/scan_cta/test', transform)\n",
    "\n",
    "# Checkpoints of the same architecture trained with 5 different seeds.\n",
    "weight_list = ['.details_CTA/checkpoints/CE3D/12-08_19:27:22_CTA3D,有Sobel,只分类阴性和夹层,1/Net_best.pth',\n",
    "               '.details_CTA/checkpoints/CE3D/12-08_20:09:24_CTA3D,有Sobel,只分类阴性和夹层,2/Net_best.pth',\n",
    "               '.details_CTA/checkpoints/CE3D/12-08_20:10:21_CTA3D,有Sobel,只分类阴性和夹层,3/Net_best.pth',\n",
    "               '.details_CTA/checkpoints/CE3D/12-08_20:10:52_CTA3D,有Sobel,只分类阴性和夹层,4/Net_best.pth',\n",
    "               '.details_CTA/checkpoints/CE3D/12-08_20:11:15_CTA3D,有Sobel,只分类阴性和夹层,5/Net_best.pth',\n",
    "              ]\n",
    "device = torch.device('cuda:2')\n",
    "net = None\n",
    "exp_auc_list, exp_acc_list, exp_sen_list, exp_spe_list = [], [], [], []\n",
    "for iw, weight in enumerate(weight_list):\n",
    "    del net  # release the previous model before loading the next checkpoint\n",
    "    net = resnet3d(34, n_channels=2, n_classes=2, conv1_t_size=3)\n",
    "    net.load_state_dict(torch.load(weight, map_location=device))\n",
    "    net.to(device)\n",
    "    net.eval()\n",
    "    \n",
    "    true_list = []\n",
    "    patient_pred_list = []\n",
    "    with torch.no_grad():\n",
    "        for jimgs, simgs, label, patient in tqdm(dataset, desc=f'{iw+1}/{len(weight_list)}', leave=True, ncols=50):\n",
    "            true_list.append(label)\n",
    "\n",
    "            # Build sliding 7-slice windows (stacked along depth dim=1).\n",
    "            jimgs = [torch.stack(jimgs[i:i+7], dim=1) for i in range(len(jimgs)-6)]\n",
    "            simgs = [torch.stack(simgs[i:i+7], dim=1) for i in range(len(simgs)-6)]\n",
    "            imgs = torch.stack(jimgs+simgs, dim=0).to(device)\n",
    "\n",
    "            # Cache the positive-class probability per window so the (N, thresh)\n",
    "            # sweeps below do not re-run the network.\n",
    "            preds = torch.softmax(net(imgs),dim=1)[:,1].tolist()\n",
    "            preds.insert(len(jimgs), -2) # sentinel: break the prediction sequence between the two aorta segments\n",
    "            preds = np.array(preds)\n",
    "            patient_pred_list.append(preds)\n",
    "    true_list = np.array(true_list)\n",
    "    \n",
    "    auc_list, acc_list, sen_list, spe_list = [], [], [], []\n",
    "    for N in range(1, N_MAX+1):\n",
    "        fpr, tpr = [0.], [0.]\n",
    "        for thresh in [i/1000 for i in range(1000,0,-1)]:\n",
    "            pred_list = []\n",
    "            for preds in patient_pred_list:\n",
    "                preds_idx = (preds > thresh).astype(np.int64).tolist()\n",
    "\n",
    "                # Patient label = class of the longest run (>= N) of identical\n",
    "                # positive predictions.\n",
    "                pred_label = 0\n",
    "                pre = -1\n",
    "                max_len = -1\n",
    "                start_idx = -1\n",
    "                for i in range(len(preds_idx)):\n",
    "                    cur = preds_idx[i]\n",
    "                    if cur != pre:\n",
    "                        if pre in [1, 2]:\n",
    "                            ln = i - start_idx\n",
    "                            if ln >= N and ln > max_len:\n",
    "                                max_len = ln\n",
    "                                pred_label = pre\n",
    "                        start_idx = i\n",
    "                    pre = cur\n",
    "                # Fix: close the final run, which the loop above never flushes --\n",
    "                # without this a positive run ending the sequence is ignored.\n",
    "                if pre in [1, 2]:\n",
    "                    ln = len(preds_idx) - start_idx\n",
    "                    if ln >= N and ln > max_len:\n",
    "                        max_len = ln\n",
    "                        pred_label = pre\n",
    "                pred_list.append(pred_label)\n",
    "            sen = metrics.recall_score(true_list, pred_list, pos_label=1)\n",
    "            spe = metrics.recall_score(true_list, pred_list, pos_label=0)\n",
    "            if thresh == 0.5:  # 500/1000 is exactly representable, so == is safe here\n",
    "                acc_list.append(metrics.accuracy_score(true_list, pred_list))\n",
    "                sen_list.append(sen)\n",
    "                spe_list.append(spe)\n",
    "            fpr.append(1-spe)\n",
    "            tpr.append(sen)\n",
    "        fpr.append(1.), tpr.append(1.)\n",
    "        fpr, tpr = np.array(fpr), np.array(tpr)\n",
    "        sorted_index = np.argsort(fpr)\n",
    "        fpr = fpr[sorted_index]\n",
    "        tpr = tpr[sorted_index]\n",
    "        # NOTE(review): scipy.integrate.trapz is deprecated (removed in SciPy 1.14);\n",
    "        # prefer integrate.trapezoid or sklearn.metrics.auc.\n",
    "        auc = integrate.trapz(y=tpr, x=fpr)#metrics.auc(fpr, tpr)\n",
    "        auc_list.append(auc)\n",
    "#         plt.plot(fpr, tpr)\n",
    "#         plt.xlim(-0.02, 1.02)\n",
    "#         plt.ylim(-0.02, 1.02)\n",
    "#         plt.show()\n",
    "    exp_auc_list.append(auc_list)\n",
    "    exp_acc_list.append(acc_list)\n",
    "    exp_sen_list.append(sen_list)\n",
    "    exp_spe_list.append(spe_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7ea8f268",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-------------------0----------------\n",
      "1 0.8713±0.0236\n",
      "2 0.9045±0.0195\n",
      "3 0.9241±0.0149\n",
      "4 0.9328±0.0063\n",
      "5 0.9531±0.0103\n",
      "6 0.9612±0.0094\n",
      "7 0.9641±0.0095\n",
      "8 0.9536±0.0061\n",
      "9 0.9566±0.0073\n",
      "10 0.9522±0.0082\n",
      "11 0.9483±0.0078\n",
      "12 0.9455±0.0083\n",
      "13 0.9549±0.0065\n",
      "14 0.9537±0.0054\n",
      "15 0.9548±0.0045\n",
      "16 0.9579±0.0020\n",
      "17 0.9583±0.0018\n",
      "18 0.9588±0.0017\n",
      "19 0.9567±0.0049\n",
      "20 0.9571±0.0050\n",
      "-------------------1----------------\n",
      "1 74.68±2.66\n",
      "2 78.48±2.26\n",
      "3 83.54±2.40\n",
      "4 86.84±2.48\n",
      "5 89.87±2.26\n",
      "6 90.89±2.03\n",
      "7 93.16±1.89\n",
      "8 93.42±0.95\n",
      "9 94.43±0.62\n",
      "10 93.92±0.95\n",
      "11 93.67±0.80\n",
      "12 93.67±0.80\n",
      "13 94.18±0.62\n",
      "14 94.94±0.00\n",
      "15 95.70±0.62\n",
      "16 95.95±0.51\n",
      "17 95.95±0.51\n",
      "18 95.95±0.51\n",
      "19 95.95±0.51\n",
      "20 95.95±0.51\n",
      "-------------------2----------------\n",
      "1 98.42±1.29\n",
      "2 97.37±0.00\n",
      "3 97.37±0.00\n",
      "4 97.37±0.00\n",
      "5 97.37±0.00\n",
      "6 97.37±0.00\n",
      "7 97.37±0.00\n",
      "8 94.74±0.00\n",
      "9 94.74±0.00\n",
      "10 92.63±1.05\n",
      "11 92.11±0.00\n",
      "12 92.11±0.00\n",
      "13 91.58±1.05\n",
      "14 91.58±1.05\n",
      "15 91.58±1.05\n",
      "16 91.58±1.05\n",
      "17 91.58±1.05\n",
      "18 91.58±1.05\n",
      "19 91.58±1.05\n",
      "20 91.58±1.05\n",
      "-------------------3----------------\n",
      "1 52.68±4.78\n",
      "2 60.98±4.36\n",
      "3 70.73±4.63\n",
      "4 77.07±4.78\n",
      "5 82.93±4.36\n",
      "6 84.88±3.90\n",
      "7 89.27±3.65\n",
      "8 92.20±1.83\n",
      "9 94.15±1.19\n",
      "10 95.12±1.54\n",
      "11 95.12±1.54\n",
      "12 95.12±1.54\n",
      "13 96.59±1.19\n",
      "14 98.05±0.98\n",
      "15 99.51±0.98\n",
      "16 100.00±0.00\n",
      "17 100.00±0.00\n",
      "18 100.00±0.00\n",
      "19 100.00±0.00\n",
      "20 100.00±0.00\n"
     ]
    }
   ],
   "source": [
    "# Summarize the experiment metrics: mean ± std across model runs, one row per\n",
    "# run-length threshold N (1..N_MAX). Index i maps to the metric being printed:\n",
    "# 0 = AUC, 1 = accuracy, 2 = sensitivity, 3 = specificity (order of the list below).\n",
    "exp_auc, exp_acc, exp_sen, exp_spe = np.array(exp_auc_list), np.array(exp_acc_list), np.array(exp_sen_list), np.array(exp_spe_list) \n",
    "for i, metr in enumerate([exp_auc, exp_acc, exp_sen, exp_spe]):\n",
    "    print(f'-------------------{i}----------------')\n",
    "    for n, (avg, std) in enumerate(zip(metr.mean(axis=0), metr.std(axis=0))):\n",
    "        if i == 0:\n",
    "            print(n+1, f'{avg:.4f}±{std:.4f}')  # AUC: 4 decimal places\n",
    "        else:\n",
    "            print(n+1, f'{avg*100:.2f}±{std*100:.2f}')  # others shown as percentages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "9ada1ff5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.78478261 0.82740275 0.85898169 0.87808924 0.89874142 0.92036613\n",
      " 0.93060641 0.94662471 0.94799771 0.95005721 0.94645309 0.94450801\n",
      " 0.94593822 0.94593822 0.94353547 0.93787185 0.94038902 0.94090389\n",
      " 0.93232265 0.9270595 ]\n",
      "[0.03395761 0.03149707 0.02695109 0.00926067 0.01863707 0.01164521\n",
      " 0.01641847 0.00735654 0.00943051 0.00651923 0.00869377 0.01245117\n",
      " 0.00921923 0.00830998 0.00765522 0.01213084 0.00874107 0.00664355\n",
      " 0.01822475 0.01196049]\n",
      "\n",
      "[0.6047619  0.66904762 0.72619048 0.77380952 0.80238095 0.8452381\n",
      " 0.87380952 0.88809524 0.89047619 0.9047619  0.90238095 0.89761905\n",
      " 0.9        0.9047619  0.90952381 0.9047619  0.8952381  0.89285714\n",
      " 0.89285714 0.89285714]\n",
      "[0.02542161 0.02756152 0.03450328 0.02129589 0.01428571 0.01304101\n",
      " 0.02208004 0.01934295 0.01749636 0.01304101 0.01388322 0.02332847\n",
      " 0.02451341 0.02714703 0.02332847 0.01304101 0.01166424 0.00752923\n",
      " 0.         0.        ]\n",
      "\n",
      "[1.         1.         0.99473684 0.98947368 0.97894737 0.97894737\n",
      " 0.97368421 0.96842105 0.95789474 0.95263158 0.93684211 0.92105263\n",
      " 0.91052632 0.91052632 0.89473684 0.87894737 0.85789474 0.85263158\n",
      " 0.84736842 0.84736842]\n",
      "[0.00000000e+00 0.00000000e+00 1.05263158e-02 1.28920513e-02\n",
      " 1.05263158e-02 1.05263158e-02 1.11022302e-16 1.05263158e-02\n",
      " 1.28920513e-02 1.05263158e-02 1.28920513e-02 2.88275030e-02\n",
      " 3.56964736e-02 3.56964736e-02 3.32871333e-02 2.10526316e-02\n",
      " 2.10526316e-02 1.28920513e-02 1.05263158e-02 1.05263158e-02]\n",
      "\n",
      "[0.27826087 0.39565217 0.50434783 0.59565217 0.65652174 0.73478261\n",
      " 0.79130435 0.82173913 0.83478261 0.86521739 0.87391304 0.87826087\n",
      " 0.89130435 0.9        0.92173913 0.92608696 0.92608696 0.92608696\n",
      " 0.93043478 0.93043478]\n",
      "[0.04642208 0.05032973 0.0650723  0.03532191 0.02535196 0.02535196\n",
      " 0.04032008 0.03478261 0.02216965 0.01626808 0.01626808 0.02216965\n",
      " 0.01944407 0.02608696 0.0173913  0.01064996 0.01064996 0.01064996\n",
      " 0.00869565 0.00869565]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Dump the raw mean/std arrays (across model runs, one entry per run-length\n",
    "# threshold N) for AUC, accuracy, sensitivity and specificity.\n",
    "# '0.4' presumably records the decision threshold in effect when these arrays\n",
    "# were produced — TODO confirm.\n",
    "for metr in [exp_auc, exp_acc, exp_sen, exp_spe]:\n",
    "    print(metr.mean(axis=0), end='\\n')\n",
    "    print(metr.std(axis=0), end='\\n\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "b1eec207",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.78478261 0.82740275 0.85898169 0.87808924 0.89874142 0.92036613\n",
      " 0.93060641 0.94662471 0.94799771 0.95005721 0.94645309 0.94450801\n",
      " 0.94593822 0.94593822 0.94353547 0.93787185 0.94038902 0.94090389\n",
      " 0.93232265 0.9270595 ]\n",
      "[0.03395761 0.03149707 0.02695109 0.00926067 0.01863707 0.01164521\n",
      " 0.01641847 0.00735654 0.00943051 0.00651923 0.00869377 0.01245117\n",
      " 0.00921923 0.00830998 0.00765522 0.01213084 0.00874107 0.00664355\n",
      " 0.01822475 0.01196049]\n",
      "\n",
      "[0.61190476 0.68095238 0.73095238 0.78333333 0.81428571 0.85714286\n",
      " 0.88333333 0.9        0.89761905 0.90714286 0.9047619  0.8952381\n",
      " 0.9        0.90714286 0.90714286 0.89761905 0.89047619 0.89047619\n",
      " 0.89285714 0.89285714]\n",
      "[0.02776644 0.03049107 0.03809524 0.01904762 0.01214052 0.01304101\n",
      " 0.01388322 0.01214052 0.0161484  0.01749636 0.01683588 0.02756152\n",
      " 0.02451341 0.01904762 0.01579345 0.01214052 0.00890871 0.00890871\n",
      " 0.00752923 0.00752923]\n",
      "\n",
      "[1.         1.         0.99473684 0.98421053 0.97894737 0.97894737\n",
      " 0.97368421 0.96842105 0.95789474 0.95263158 0.93684211 0.91052632\n",
      " 0.90526316 0.9        0.88421053 0.86315789 0.84736842 0.84210526\n",
      " 0.84210526 0.84210526]\n",
      "[0.00000000e+00 0.00000000e+00 1.05263158e-02 1.28920513e-02\n",
      " 1.05263158e-02 1.05263158e-02 1.11022302e-16 1.05263158e-02\n",
      " 1.28920513e-02 1.05263158e-02 1.28920513e-02 3.93858672e-02\n",
      " 3.56964736e-02 3.86761538e-02 2.68369448e-02 1.96929336e-02\n",
      " 1.05263158e-02 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n",
      "\n",
      "[0.29130435 0.4173913  0.51304348 0.6173913  0.67826087 0.75652174\n",
      " 0.80869565 0.84347826 0.84782609 0.86956522 0.87826087 0.8826087\n",
      " 0.89565217 0.91304348 0.92608696 0.92608696 0.92608696 0.93043478\n",
      " 0.93478261 0.93478261]\n",
      "[0.05070393 0.05567934 0.0709109  0.02948839 0.02129991 0.02535196\n",
      " 0.02535196 0.01626808 0.01944407 0.02381402 0.02216965 0.02608696\n",
      " 0.02129991 0.02381402 0.01064996 0.01064996 0.01064996 0.01626808\n",
      " 0.01374903 0.01374903]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Same raw mean/std dump as the previous cell (AUC, accuracy, sensitivity,\n",
    "# specificity across model runs, per run-length threshold N).\n",
    "# '0.5' presumably records the decision threshold in effect for this run — TODO confirm.\n",
    "for metr in [exp_auc, exp_acc, exp_sen, exp_spe]:\n",
    "    print(metr.mean(axis=0), end='\\n')\n",
    "    print(metr.std(axis=0), end='\\n\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "d0a9c88c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAcQUlEQVR4nO3dd3jV5f3/8eebEPaGsKdsZESN4B61KqgtblBbW0epba2gVrFWra1d2qpoa7XUWjsNoohoUatWqy11hBICYYYdVsKGQCDj/f0j+fUXY0wOcJL7jNfjunJd+ZzzSc6L2+Tlnfvc53zM3RERkfjXKHQAERGJDhW6iEiCUKGLiCQIFbqISIJQoYuIJIjGoR64U6dO3rdv31APLyISl+bPn7/N3dNqui9Yofft25esrKxQDy8iEpfMbN1n3aclFxGRBKFCFxFJECp0EZEEoUIXEUkQKnQRkQRRZ6Gb2TNmVmBmiz/jfjOzx80sz8xyzOz46McUEZG6RDJDfxYYW8v944CBlR+TgCePPpaIiByuOvehu/t7Zta3llPGA3/0ivfh/cDM2plZN3ffHK2QIkfj3eUF/HfdztAxRP4no28HzhhU42uDjko0XljUA9hQ5Ti/8rZPFbqZTaJiFk/v3r2j8NAitZu/bgfXP/sx5Q5modOIVLjpzP4xW+g1/ZrUeNUMd58OTAfIyMjQlTWkXu0tLmFyZjY92jdn7i2n07pZauhIIvUqGoWeD/SqctwT2BSF7ytyVL7/ci6bdxfz/NdPVplLUojGtsU5wLWVu11OAnZr/VxCezl7I7MWbOTbnxvACX3ah44j0iDqnKGb2XPAWUAnM8sHvg+kArj7U8Bc4AIgD9gPXFdfYUUikb9zP/fMXszxvdtx89kDQscRaTCR7HK5qo77HfhW1BKJHIWycue2GQtxh2kTjqNxil47J8kj2NvnitSHp/65io/W7uDhK0bRu2OL0HFEGpSmL5Iwsjfs4tE3V3DRyG5cenyP0HFEGpwKXRJC0cFSpmQuoHPrpvz44hGYNp1LEtKSiySEH76yhHU79vPc106ibQttUZTkpBm6xL3XF29mRtYGvnFmf046pmPoOCLBqNAlrm3ZXcxdsxYxsmdbpnx+UOg4IkGp0CVulZc7t8/M5mBJOdMmpNOksX6cJbnpN0Di1u/+tYZ/523n+18YxjFprULHEQlOhS5xKXfTbh56YxnnH9uFCSf2qvsLRJKACl3izoFDZUzOzKZDyyb87NKR2qIoUknbFiXu/GTuUvIK9vHnG8bQvmWT0HFEYoZm6BJX3l66lT99sI6vnd6P0wZ2Ch1HJKao0CVuFOwt5s4XchjarQ3fOX9w6DgiMUdLLhIX3J07Zuaw72ApmRPTado4JXQkkZijGbrEhT/MW8s/VxTyvQuHMrBL69BxRGKSCl1i3vIte/nJa8s4e3AaXz6pT+g4IjFLhS4xrbikjMmZC2jTrDEPXT5KWxRFaqE1dIlpD72+nGVb9vL7r55IWuumoeOIxDTN0CVmvbeikGf+vYavnNyHs4d0Dh1HJOap0CUm7Sg6xO0zFzKwcyu+e8HQ0HFE4oKWXCTmuDtTX8xh9/4S/nDdaJqlaouiSCQ0Q5eY89xHG3hzyVbuHDuYYd3bhI4jEjdU6BJT8gr28cNXczl9YCeuP7Vf6DgicUWFLjHjUGk5U2YsoHlqCr+4YhSNGmmLosjh0Bq6xIxH3lzB4o17+M2XT6BLm2ah44jEHc3QJSbMW7WN37y3iqtG9+b8Y7uGjiMSl1ToEtyu/Ye4bcZC+nVsyb0XaYuiyJHSkosE5e7c/dIitu07yEvfPJUWTfQjKXKkNEOXoF6Yn8/cRVu47bxBjOjZNnQckbimQpdg1m4r4v45uYzp14Gvn9E/dByRuKdClyBKysqZMiOblEbGoxPSSdEWRZGjFlGhm9lYM1tuZnlmdlcN97c1s1fMbKGZ5ZrZddGPKonkl2+v
JHvDLn5y6Qi6t2seOo5IQqiz0M0sBXgCGAcMA64ys2HVTvsWsMTdRwFnAQ+bmS7HLjXKWruDX72Tx2XH9+Sikd1DxxFJGJHM0EcDee6+2t0PAZnA+GrnONDaKq4+0ArYAZRGNakkhD3FJUyZkU3P9i24/4vV5wUicjQiKfQewIYqx/mVt1X1K2AosAlYBEx29/Lq38jMJplZlpllFRYWHmFkiWfffzmXzbuLeXRCOq2bpYaOI5JQIin0mp6t8mrH5wPZQHcgHfiVmX3qbfLcfbq7Z7h7Rlpa2mFGlXj3cvZGXlqwkVs+N5AT+rQPHUck4URS6PlAryrHPamYiVd1HTDLK+QBa4Ah0YkoiWDDjv3c89JiTujTnm+drS2KIvUhkkL/GBhoZv0qn+icCMypds564BwAM+sCDAZWRzOoxK+ycue257NxYNqEdBqnaLesSH2o83XW7l5qZjcDbwApwDPunmtmN1Xe/xTwAPCsmS2iYolmqrtvq8fcEkeefDePj9fu5NEJo+jVoUXoOCIJK6I3znD3ucDcarc9VeXzTcB50Y0miSB7wy4efWslXxzVnYvTqz+XLiLRpL99pd4UHSxlcuYCurZpxgMXD6diV6uI1Be9tZ3Umx+8ksuGHfvJnHQybZtri6JIfdMMXerF3EWbeT4rn2+c1Z/R/TqEjiOSFFToEnWbdx/gu7MWMbJnW6Z8flDoOCJJQ4UuUVVe7tz+/EIOlZbz2MTjSNUWRZEGozV0iaqn/7Waeau28+BlI+jXqWXoOCJJRdMniZrFG3fz8zeWM/bYrlyZ0avuLxCRqFKhS1QcOFTG5MwFdGjZhJ9eOkJbFEUC0JKLRMWP5y5hVWERf7lxDO1b6q3wRULQDF2O2ltLtvLnD9Yz6YxjOHVAp9BxRJKWCl2OSsHeYu58MYdh3dpw+3naoigSkgpdjlh5ufOdmTkUHSzl8avSado4JXQkkaSmQpcj9of/rOW9FYXcc9EwBnRuHTqOSNJTocsRWbZlDz99bRnnDOnMl8b0Dh1HRFChyxEoLilj8nPZtGmWyoOXj9QWRZEYoW2LctgefH0Zy7fu5dnrTqRTq6ah44hIJc3Q5bC8u7yA3/97LV89pS9nDe4cOo6IVKFCl4ht33eQ78zMYVCXVtw1TtcAF4k1WnKRiLg7U1/MYc+BEv50w2iapWqLokis0QxdIvLXj9bz1tICpo4bwtBubULHEZEaqNClTnkF+3jg1SWcPrAT153SN3QcEfkMKnSp1aHSciZnLqB5agoPXzGKRo20RVEkVmkNXWr18JvLyd20h+lfPoHObZqFjiMitdAMXT7TvLxtTH9vNVeP6c15x3YNHUdE6qBClxrt2n+I255fSL9OLbnnwqGh44hIBFTo8inuzndnLWJ70UEen3gcLZpoZU4kHqjQ5VNmzs/ntcVbuP28wQzv0TZ0HBGJkApdPmHttiLun5PLycd0ZNLpx4SOIyKHQYUu/1NSVs7kGdmkpjTi4Su1RVEk3mhxVP7n8bdXsnDDLp64+ni6t2seOo6IHKaIZuhmNtbMlptZnpnd9RnnnGVm2WaWa2b/jG5MqW8frdnBE+/kcfkJPblwZLfQcUTkCNQ5QzezFOAJ4FwgH/jYzOa4+5Iq57QDfg2Mdff1Zqb3VY0je4pLuHVGNj3bt+D+Lx4bOo6IHKFIZuijgTx3X+3uh4BMYHy1c64GZrn7egB3L4huTKlP981ezJY9xUybmE6rplqFE4lXkRR6D2BDleP8ytuqGgS0N7N3zWy+mV1b0zcys0lmlmVmWYWFhUeWWKJq9oKNzM7exORzBnJ87/ah44jIUYik0Gva6uDVjhsDJwAXAucD95rZoE99kft0d89w94y0tLTDDivRtWHHfu6dvZiMPu355ln9Q8cRkaMUyd/X+UCvKsc9gU01nLPN3YuAIjN7DxgFrIhKSom60rJybp2RDcCjE9JpnKIdrCLxLpLf4o+BgWbWz8yaABOBOdXOeRk4
3cwam1kLYAywNLpRJZqefHcVWet28sDFw+nVoUXoOCISBXXO0N291MxuBt4AUoBn3D3XzG6qvP8pd19qZq8DOUA58LS7L67P4HLkFqzfybS3VzI+vTsXH1f96RARiVfmXn05vGFkZGR4VlZWkMdOZvsOlnLh4+9TWubMnXw6bZunho4kIofBzOa7e0ZN92mPWpL5wZxcNuzYT+akk1XmIglGz4Qlkb/lbGbm/Hy+dfYARvfrEDqOiESZCj1JbNp1gO/OymFUr3bccs7A0HFEpB6o0JNAWblz2/PZlJY7j01IJ1VbFEUSktbQk8Bv31/NB6t38NBlI+nbqWXoOCJSTzRVS3CLN+7m4b8vZ9zwrlyR0TN0HBGpRyr0BHbgUBm3ZC6gY8um/PTSEZjpghUiiUxLLgnsR39bwpptRfzlhjG0a9EkdBwRqWeaoSeoN5ds5S8frmfS6cdwyoBOoeOISANQoSeggj3FTH0xh2O7t+G28z71ppcikqBU6AmmvNy5feZC9h8q5bGJ6TRtnBI6kog0EBV6gnl23lreX7mNey4cxoDOrUPHEZEGpEJPIEs37+Fnry3j80M7c82Y3qHjiEgDU6EniOKSMqZkZtOmeSoPXjZSWxRFkpC2LSaIn722jOVb9/LsdSfSsVXT0HFEJADN0BPAO8sLeHbeWq47tS9nDe4cOo6IBKJCj3Pb9h3kjpk5DO7Smqljh4SOIyIBackljrk7U1/IYU9xCX++cTTNUrVFUSSZaYYex/784XreXlbAXWOHMKRrm9BxRCQwFXqcyivYy49eXcIZg9L46il9Q8cRkRigQo9DB0vLuOW5bFo2bcwvLh9Jo0baoigiWkOPS4/8fQVLNu/ht9dm0LlNs9BxRCRGaIYeZ/6dt43fvLeaa8b05txhXULHEZEYokKPIzuLDnH78ws5Jq0l91w4LHQcEYkxKvQ44e7c/dIithcd5PGJx9G8ibYoisgnqdDjxMysfF5bvIXvnDeY4T3aho4jIjFIhR4H1mwr4v5Xcjmlf0e+dvoxoeOISIxSoce4krJypmQuIDWlEQ9fOUpbFEXkM2nbYox77K2VLMzfzZPXHE+3ts1DxxGRGKYZegz7cPV2nng3jyszejJuRLfQcUQkxqnQY9TuAyXc9vxC+nRowfe/cGzoOCISByIqdDMba2bLzSzPzO6q5bwTzazMzC6PXsTk4+7cM3sxW/YUM23icbRsqpUxEalbnYVuZinAE8A4YBhwlZl96lUtlec9CLwR7ZDJZnb2Rl5ZuIkp5wwkvVe70HFEJE5EMkMfDeS5+2p3PwRkAuNrOO/bwItAQRTzJZ0NO/Zz3+xcTuzbnm+ePSB0HBGJI5EUeg9gQ5Xj/Mrb/sfMegCXAE/V9o3MbJKZZZlZVmFh4eFmTXilZeXcOiMbgEeuTCdFWxRF5DBEUug1tYpXO54GTHX3stq+kbtPd/cMd89IS0uLMGLy+PW7q8hat5MfXTKcXh1ahI4jInEmkmfb8oFeVY57ApuqnZMBZJoZQCfgAjMrdffZ0QiZDP67fiePvb2Si9O7Mz69R91fICJSTSSF/jEw0Mz6ARuBicDVVU9w937/73MzexZ4VWUeuX0HS5mSmU3XNs344cXDQ8cRkThVZ6G7e6mZ3UzF7pUU4Bl3zzWzmyrvr3XdXOp2/5xc8nfuZ8bXT6ZNs9TQcUQkTkW0wdnd5wJzq91WY5G7+1ePPlbyeDVnEy/Mz+eWzw3gxL4dQscRkTimV4oGtGnXAe6etYj0Xu349jkDQ8cRkTinQg+krNy5dUY2ZeXOYxPTSU3RfwoROTp6TXkg099bzYdrdvDzy0fSp2PL0HFEJAFoWhjAovzdPPz35Vw4ohuXn9AzdBwRSRAq9Aa2/1ApkzMXkNa6KT++ZDiVe/dFRI6allwa2AOvLmXN9iL+cuMY2rVoEjqOiCQQzdAb0N9zt/DcR+uZdMYxnNK/U+g4IpJgVOgNpGBPMVNfzGF4jzbcfu7g
0HFEJAGp0BtAeblz+8yFHCgpY9qE42jSWMMuItGnZmkAv5+3lvdXbuPei4YxoHOr0HFEJEGp0OvZ0s17ePC1ZXx+aBeuHt07dBwRSWAq9HpUXFLG5MwFtG2RyoOXjdAWRRGpV9q2WI9+9toyVmzdxx+uH03HVk1DxxGRBKcZej15Z1kBz85by/Wn9uPMQbo6k4jUPxV6Pdi27yB3vLCQIV1bc+dYbVEUkYahJZcoc3fufCGHPcWl/OXGk2iWmhI6kogkCc3Qo+zPH6zjH8sKuHvcEAZ3bR06jogkERV6FK3cupcf/W0pZw5K4yun9A0dR0SSjAo9Sg6WlnFLZjatmjbm51eM1BZFEWlwWkOPkl+8sZylm/fwu69k0Ll1s9BxRCQJaYYeBf9auY3fvr+GL53Um3OGdgkdR0SSlAr9KO0sOsTtM7Ppn9aS710wLHQcEUliWnI5Cu7OXbNy2FF0iN995USaN9EWRREJRzP0o/B81gbeyN3KHecPZniPtqHjiEiSU6EfodWF+7h/zhJO6d+RG087JnQcEREV+pEoKStnyoxsmjRuxCNXptOokbYoikh4WkM/AtPeWkFO/m6evOZ4urbVFkURiQ2aoR+mD1dv59fvrmJCRi/GjegWOo6IyP+o0A/D7gMl3Dojmz4dWnDfF7RFUURii5ZcIuTufO+lRRTsPciL3ziFlk01dCISWyKaoZvZWDNbbmZ5ZnZXDfdfY2Y5lR/zzGxU9KOG9dKCjbyas5lbzx3EqF7tQscREfmUOgvdzFKAJ4BxwDDgKjOrvt6wBjjT3UcCDwDTox00pPXb93Pfy7mM7tuBm87sHzqOiEiNIpmhjwby3H21ux8CMoHxVU9w93nuvrPy8AOgZ3RjhlNaVs6UGQswg0cmjCJFWxRFJEZFUug9gA1VjvMrb/ssNwCv1XSHmU0ysywzyyosLIw8ZUBPvLOK/67fxY8uHk7P9i1CxxER+UyRFHpNU1Kv8USzs6ko9Kk13e/u0909w90z0tJi/8LJ89ft5PF/rOSS43owPr22/4eJiIQXyVaNfKBXleOewKbqJ5nZSOBpYJy7b49OvHD2FpcwZcYCurVtxg/GHxs6johInSKZoX8MDDSzfmbWBJgIzKl6gpn1BmYBX3b3FdGP2fDun7OEjTsPMG1COm2apYaOIyJSpzpn6O5eamY3A28AKcAz7p5rZjdV3v8UcB/QEfh15aXXSt09o/5i169XFm7ixf/mc8s5A8no2yF0HBGRiJh7jcvh9S4jI8OzsrKCPHZtNu46wLhp79G/cytmfv1kGqfoxbQiEjvMbP5nTZjVVlWUlTu3zcimrNyZNiFdZS4icUWvX6/iN++t4sM1O/jFFaPo07Fl6DgiIodFU9BKOfm7eOTvK7hwZDcuO15bFEUk/qjQgf2HSpmcmU1a66b85OIRVD6xKyISV7TkAjzw6hLWbi/irzeeRNsW2qIoIvEp6Wfory/ewnMfbeCmM/tzcv+OoeOIiByxpC70rXuKuWtWDiN6tOXWzw8KHUdE5KgkbaGXlzvfmbmQgyXlTJuYTpPGSTsUIpIgkrbFnvn3Gt5fuY17LxpG/7RWoeOIiBy1pCz0JZv28NDryzl3WBeuGt2r7i8QEYkDSVfoxSVlTM5cQNsWqTx42UhtURSRhJF02xZ/OncpKwv28cfrR9OhZZPQcUREoiapZujvLCvgD/9Zxw2n9eOMQbF/gQ0RkcORNIVeuPcgd7ywkCFdW3PH+YNDxxERibqkWHJxd+58YSF7i0v569dOollqSuhIIiJRlxQz9D99sI53lhdy9wVDGdSldeg4IiL1IuELfcXWvfz4b0s5e3Aa157cJ3QcEZF6k9CFfrC0jFueW0Crpo156PJR2qIoIgktodfQf/76cpZt2cszX80grXXT0HFEROpVws7Q319ZyNP/WsO1J/fhc0O6hI4jIlLvErLQdxYd4vbnFzKgcyvuvmBo6DgiIg0i4Qrd3blrVg679pfw
2MR0bVEUkaSRcIU+4+MNvJG7lTvOH8yx3duGjiMi0mASqtBXF+7jB68s4dQBHbnhtH6h44iINKiEKfRDpeVMzsymaWojHr4inUaNtEVRRJJLwmxbnPbWChZt3M1TXzqBrm2bhY4jItLgEmKG/sHq7Tz5z1VMPLEXY4d3DR1HRCSIuC/03ftLuG1GNn07tuTei4aFjiMiEkxcL7m4O3fPXkTB3oO8+I1TaNk0rv85IiJHJa5n6LP+u5G/5Wzm1nMHMapXu9BxRESCittCX7e9iPteXszofh246cz+oeOIiAQXUaGb2VgzW25meWZ2Vw33m5k9Xnl/jpkdH/2o/19pWTlTZmTTqJHx6IR0UrRFUUSk7kI3sxTgCWAcMAy4ysyqP/s4DhhY+TEJeDLKOT/hl//IY8H6XfzkkhH0aNe8Ph9KRCRuRDJDHw3kuftqdz8EZALjq50zHvijV/gAaGdm3aKcFYD563bwy3+s5NLje/CFUd3r4yFEROJSJIXeA9hQ5Ti/8rbDPQczm2RmWWaWVVhYeLhZAWiSksJpA9P4wRePPaKvFxFJVJEUek0L1H4E5+Du0909w90z0tLSIsn3KSN6tuWP14+mdbPUI/p6EZFEFUmh5wO9qhz3BDYdwTkiIlKPIin0j4GBZtbPzJoAE4E51c6ZA1xbudvlJGC3u2+OclYREalFnS+tdPdSM7sZeANIAZ5x91wzu6ny/qeAucAFQB6wH7iu/iKLiEhNInqtvLvPpaK0q972VJXPHfhWdKOJiMjhiNtXioqIyCep0EVEEoQKXUQkQajQRUQShFU8nxnggc0KgXVH+OWdgG1RjJNoND610/jUTuNTu9Dj08fda3xlZrBCPxpmluXuGaFzxCqNT+00PrXT+NQulsdHSy4iIglChS4ikiDitdCnhw4Q4zQ+tdP41E7jU7uYHZ+4XEMXEZFPi9cZuoiIVKNCFxFJEDFd6LF2cepYE8H4XFM5LjlmNs/MRoXIGUpd41PlvBPNrMzMLm/IfCFFMjZmdpaZZZtZrpn9s6EzhhTB71ZbM3vFzBZWjk9svMOsu8fkBxVv1bsKOAZoAiwEhlU75wLgNSqumHQS8GHo3DE2PqcA7Ss/H6fx+eT4VDnvH1S8m+jloXPHytgA7YAlQO/K486hc8fY+NwNPFj5eRqwA2gSOnssz9Bj6uLUMajO8XH3ee6+s/LwAyquJJUsIvn5Afg28CJQ0JDhAotkbK4GZrn7egB31/h8kgOtzcyAVlQUemnDxvy0WC70qF2cOkEd7r/9Bir+mkkWdY6PmfUALgGeIrlE8rMzCGhvZu+a2Xwzu7bB0oUXyfj8ChhKxaU2FwGT3b28YeJ9togucBFI1C5OnaAi/reb2dlUFPpp9ZootkQyPtOAqe5eVjHRShqRjE1j4ATgHKA58B8z+8DdV9R3uBgQyficD2QDnwP6A2+a2fvuvqees9UqlgtdF6euXUT/djMbCTwNjHP37Q2ULRZEMj4ZQGZlmXcCLjCzUnef3SAJw4n0d2ubuxcBRWb2HjAKSIZCj2R8rgN+5hWL6HlmtgYYAnzUMBFrFstLLro4de3qHB8z6w3MAr6cJDOrquocH3fv5+593b0v8ALwzSQoc4jsd+tl4HQza2xmLYAxwNIGzhlKJOOznoq/XjCzLsBgYHWDpqxBzM7QXRenrlWE43Mf0BH4deUstNRj9F3ioi3C8UlKkYyNuy81s9eBHKAceNrdF4dL3XAi/Nl5AHjWzBZRsUQz1d2Dv+WwXvovIpIgYnnJRUREDoMKXUQkQajQRUQShApdRCRBqNBFRBKECl1EJEGo0EVEEsT/AXeia+po3SJyAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plot the ROC curve from the fpr/tpr arrays.\n",
    "# NOTE(review): fpr/tpr are leftover kernel state from an earlier cell (the last\n",
    "# threshold sweep), so this cell breaks Restart & Run All — consider re-deriving\n",
    "# them locally or plotting inside the loop that computes them.\n",
    "plt.plot(fpr, tpr)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85fb7ece",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e0522f09",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Single-slice 2D model: load a trained checkpoint for evaluation.\n",
    "device = torch.device('cuda:2')\n",
    "# Checkpoint path (an alternative checkpoint is kept in the trailing comment for reference).\n",
    "load_model = '.details_CTA/checkpoints/CE/12-13_16:03:28_1CTA,Sobel,只分类阴性和夹层/Net_best.pth'#'.details/checkpoints/MD/09-13_16:23:37_1最优权重正梯度/Net_best.pth'\n",
    "net = resnet(34, n_channels=2, n_classes=2)  # ResNet-34 variant, 2-channel input, binary output\n",
    "net.load_state_dict(torch.load(load_model, map_location=device))\n",
    "net.to(device)\n",
    "net.eval()  # inference mode (disables dropout / freezes batch-norm statistics)\n",
    "print()  # swallow the repr that net.eval() would otherwise echo as cell output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fcbdafbc",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:34<00:00,  2.32it/s]\n",
      "2: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:36<00:00,  2.17it/s]\n",
      "3: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:38<00:00,  2.08it/s]\n",
      "4: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:38<00:00,  2.05it/s]\n",
      "5: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:43<00:00,  1.80it/s]\n",
      "6: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:49<00:00,  1.59it/s]\n",
      "7: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:42<00:00,  1.87it/s]\n",
      "8: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:48<00:00,  1.63it/s]\n",
      "9: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:45<00:00,  1.72it/s]\n",
      "10: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:41<00:00,  1.92it/s]\n",
      "11: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:46<00:00,  1.70it/s]\n",
      "12: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:44<00:00,  1.78it/s]\n",
      "13: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:39<00:00,  2.00it/s]\n",
      "14: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:41<00:00,  1.91it/s]\n",
      "15: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:41<00:00,  1.92it/s]\n",
      "16: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:41<00:00,  1.89it/s]\n",
      "17: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:47<00:00,  1.67it/s]\n",
      "18: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:41<00:00,  1.90it/s]\n",
      "19: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:42<00:00,  1.88it/s]\n",
      "20: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:40<00:00,  1.93it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "20 0.9493670886075949\n",
      "report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.9111    1.0000    0.9535        41\n",
      "           1     1.0000    0.8947    0.9444        38\n",
      "\n",
      "    accuracy                         0.9494        79\n",
      "   macro avg     0.9556    0.9474    0.9490        79\n",
      "weighted avg     0.9539    0.9494    0.9491        79\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# 2D evaluation: feed each slice through the network separately, then derive a\n",
    "# patient-level label from the longest run of identical positive slice predictions\n",
    "# whose length is at least N; sweep N and keep the best accuracy.\n",
    "# Fix 1: run inference ONCE per patient and reuse the cached per-slice predictions\n",
    "#        for every N (the previous version repeated the full forward pass N_MAX times).\n",
    "# Fix 2: also evaluate a run that reaches the end of the sequence (the previous scan\n",
    "#        only closed a run when the next label differed, dropping trailing runs).\n",
    "N_MAX = 20\n",
    "\n",
    "class PatientDataset(Dataset):\n",
    "    \"\"\"One item per patient: the two aorta slice stacks ('j_*' and 's_*' files) plus the class label.\"\"\"\n",
    "    def __init__(self, root, transform):\n",
    "        self.datas = []\n",
    "        self.transform = transform\n",
    "        for i in ['0', '1']:  # class sub-directories, directory name is the label\n",
    "            cpath = os.path.join(root, i)\n",
    "            i = int(i)\n",
    "            for patient in sorted(os.listdir(cpath)):\n",
    "                ppath = os.path.join(cpath, patient)\n",
    "                imgs = os.listdir(ppath)\n",
    "                pd = {}\n",
    "                pd['j'] = list(map(lambda x: os.path.join(ppath, x), sorted(filter(lambda x: x.startswith('j_'), imgs))))\n",
    "                pd['s'] = list(map(lambda x: os.path.join(ppath, x), sorted(filter(lambda x: x.startswith('s_'), imgs))))\n",
    "                self.datas.append([patient, pd, i])\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.datas)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        patient, pd, label = self.datas[index]\n",
    "        jimgs = list(map(lambda x: Image.open(x), pd['j']))\n",
    "        simgs = list(map(lambda x: Image.open(x), pd['s']))\n",
    "        jimgs = self.transform(jimgs)\n",
    "        simgs = self.transform(simgs)\n",
    "        return jimgs, simgs, label, patient\n",
    "\n",
    "vt_list = [\n",
    "    MT.Resize3D(81),\n",
    "    MT.CenterCrop3D(81),\n",
    "    MT.ToTensor3D(),\n",
    "    MT.SobelChannel(3, flag_3d=True)\n",
    "]\n",
    "transform = T.Compose(vt_list)\n",
    "dataset = PatientDataset('/nfs3-p2/zsxm/dataset/scan_cta/test', transform)\n",
    "\n",
    "def longest_run_label(preds_idx, N):\n",
    "    \"\"\"Return the label (1 or 2) of the longest run of identical positive predictions of length >= N, else 0.\"\"\"\n",
    "    pred_label = 0\n",
    "    pre = -1\n",
    "    max_len = -1\n",
    "    start_idx = -1\n",
    "    for i in range(len(preds_idx) + 1):  # +1: virtual sentinel flushes a run reaching the end\n",
    "        cur = preds_idx[i] if i < len(preds_idx) else -3\n",
    "        if cur != pre:\n",
    "            if pre in [1, 2]:\n",
    "                ln = i - start_idx\n",
    "                if ln >= N and ln > max_len:\n",
    "                    max_len = ln\n",
    "                    pred_label = pre\n",
    "            start_idx = i\n",
    "        pre = cur\n",
    "    return pred_label\n",
    "\n",
    "# Inference pass: per-slice predictions do not depend on N, so run the network once.\n",
    "true_list = []\n",
    "patient_preds = []\n",
    "with torch.no_grad():\n",
    "    for jimgs, simgs, label, patient in tqdm(dataset, desc='infer'):\n",
    "        true_list.append(label)\n",
    "        imgs = torch.stack(jimgs+simgs, dim=0).to(device)\n",
    "        preds_idx = net(imgs).argmax(dim=1).tolist()\n",
    "        preds_idx.insert(len(jimgs), -2)  # break the run between the descending and ascending aorta sequences\n",
    "        patient_preds.append(preds_idx)\n",
    "\n",
    "# Sweep the run-length threshold N over the cached predictions; keep the best accuracy.\n",
    "max_acc = -1\n",
    "max_n = -1\n",
    "max_true_list = []\n",
    "max_pred_list = []\n",
    "for N in range(1, N_MAX+1):\n",
    "    pred_list = [longest_run_label(preds_idx, N) for preds_idx in patient_preds]\n",
    "    acc = metrics.accuracy_score(true_list, pred_list)\n",
    "    if acc > max_acc:\n",
    "        max_acc = acc\n",
    "        max_n = N\n",
    "        max_true_list = true_list\n",
    "        max_pred_list = pred_list\n",
    "\n",
    "print(max_n, max_acc)\n",
    "print(f'report:\\n'+metrics.classification_report(max_true_list, max_pred_list, digits=4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6467bb0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): true_list/pred_list here are leftover kernel state from the\n",
    "# evaluation cell — pred_list holds the LAST N iteration, not the best-N result\n",
    "# (that is max_true_list/max_pred_list).\n",
    "print(true_list)\n",
    "print(pred_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f16f365e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1/1: 100%|████████| 79/79 [00:55<00:00,  1.43it/s]\n"
     ]
    }
   ],
   "source": [
    "#2D统计四个统计量\n",
    "N_MAX = 50\n",
    "\n",
    "class PatientDataset(Dataset):\n",
    "    def __init__(self, root, transform):\n",
    "        self.datas = []\n",
    "        self.transform = transform\n",
    "        for i in ['0', '1']:\n",
    "            cpath = os.path.join(root, i)\n",
    "            i = int(i)\n",
    "            for patient in sorted(os.listdir(cpath)):\n",
    "                ppath = os.path.join(cpath, patient)\n",
    "                imgs = os.listdir(ppath)\n",
    "                pd = {}\n",
    "                pd['j'] = list(map(lambda x: os.path.join(ppath, x), sorted(filter(lambda x: x.startswith('j_'), imgs))))\n",
    "                pd['s'] = list(map(lambda x: os.path.join(ppath, x), sorted(filter(lambda x: x.startswith('s_'), imgs))))\n",
    "                self.datas.append([patient, pd, i])\n",
    "        \n",
    "    def __len__(self):\n",
    "        return len(self.datas)\n",
    "        \n",
    "    def __getitem__(self, index):\n",
    "        patient, pd, label = self.datas[index]\n",
    "        jimgs = list(map(lambda x: Image.open(x), pd['j']))\n",
    "        simgs = list(map(lambda x: Image.open(x), pd['s']))\n",
    "        jimgs = self.transform(jimgs)\n",
    "        simgs = self.transform(simgs)\n",
    "        return jimgs, simgs, label, patient\n",
    "    \n",
    "vt_list = [\n",
    "    MT.Resize3D(81),\n",
    "    MT.CenterCrop3D(81),\n",
    "    MT.ToTensor3D(),\n",
    "    MT.SobelChannel(3, flag_3d=True)\n",
    "]\n",
    "transform = T.Compose(vt_list)\n",
    "dataset = PatientDataset('/nfs3-p2/zsxm/dataset/scan_cta/test', transform)\n",
    "\n",
    "weight_list = ['.details_CTA/checkpoints/CE/12-13_16:03:28_1CTA,Sobel,只分类阴性和夹层/Net_best.pth',\n",
    "#                '.details_CTA/checkpoints/CE/12-13_16:06:56_2CTA,Sobel,只分类阴性和夹层/Net_best.pth',\n",
    "#                '.details_CTA/checkpoints/CE/12-13_16:07:13_3CTA,Sobel,只分类阴性和夹层/Net_best.pth',\n",
    "#                '.details_CTA/checkpoints/CE/12-13_16:07:36_4CTA,Sobel,只分类阴性和夹层/Net_best.pth',\n",
    "#                '.details_CTA/checkpoints/CE/12-13_16:08:10_5CTA,Sobel,只分类阴性和夹层/Net_best.pth',\n",
    "              ]\n",
    "device = torch.device('cuda:2')\n",
    "net = None\n",
    "exp_auc_list, exp_acc_list, exp_sen_list, exp_spe_list = [], [], [], []\n",
    "for iw, weight in enumerate(weight_list):\n",
    "    del net\n",
    "    net = resnet(34, n_channels=2, n_classes=2)\n",
    "    net.load_state_dict(torch.load(weight, map_location=device))\n",
    "    net.to(device)\n",
    "    net.eval()\n",
    "    \n",
    "    true_list = []\n",
    "    patient_pred_list = []\n",
    "    with torch.no_grad():\n",
    "        for jimgs, simgs, label, patient in tqdm(dataset, desc=f'{iw+1}/{len(weight_list)}', leave=True, ncols=50):\n",
    "            true_list.append(label)\n",
    "            \n",
    "            imgs = torch.stack(jimgs+simgs, dim=0).to(device)\n",
    "\n",
    "            preds = torch.softmax(net(imgs),dim=1)[:,1].tolist()\n",
    "            preds.insert(len(jimgs), -2) #中断降主动脉和升主动脉之间的预测序列\n",
    "            preds = np.array(preds)\n",
    "            patient_pred_list.append(preds)\n",
    "    true_list = np.array(true_list)\n",
    "    \n",
    "    auc_list, acc_list, sen_list, spe_list = [], [], [], []\n",
    "    for N in range(1, N_MAX+1):\n",
    "        fpr, tpr = [0.], [0.]\n",
    "        for thresh in [i/1000 for i in range(1000,0,-1)]:\n",
    "            pred_list = []\n",
    "            for preds in patient_pred_list:\n",
    "                preds_idx = (preds > thresh).astype(np.int64).tolist()\n",
    "\n",
    "                pred_label = 0\n",
    "                pre = -1\n",
    "                max_len = -1\n",
    "                start_idx = -1\n",
    "                for i in range(len(preds_idx)):\n",
    "                    cur = preds_idx[i]\n",
    "                    if cur != pre:\n",
    "                        if pre in [1, 2]:\n",
    "                            ln = i - start_idx\n",
    "                            if ln >= N and ln > max_len:\n",
    "                                max_len = ln\n",
    "                                pred_label = pre\n",
    "                        start_idx = i\n",
    "                    pre = cur\n",
    "                pred_list.append(pred_label)\n",
    "            sen = metrics.recall_score(true_list, pred_list, pos_label=1)\n",
    "            spe = metrics.recall_score(true_list, pred_list, pos_label=0)\n",
    "            if thresh == 0.5:\n",
    "                acc_list.append(metrics.accuracy_score(true_list, pred_list))\n",
    "                sen_list.append(sen)\n",
    "                spe_list.append(spe)\n",
    "            fpr.append(1-spe)\n",
    "            tpr.append(sen)\n",
    "        fpr.append(1.), tpr.append(1.)\n",
    "        fpr, tpr = np.array(fpr), np.array(tpr)\n",
    "        sorted_index = np.argsort(fpr)\n",
    "        fpr = fpr[sorted_index]\n",
    "        tpr = tpr[sorted_index]\n",
    "        auc = integrate.trapz(y=tpr, x=fpr)#metrics.auc(fpr, tpr)\n",
    "        auc_list.append(auc)\n",
    "    exp_auc_list.append(auc_list)\n",
    "    exp_acc_list.append(acc_list)\n",
    "    exp_sen_list.append(sen_list)\n",
    "    exp_spe_list.append(spe_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1af6bea3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-------------------0----------------\n",
      "1 0.7474±0.0000\n",
      "2 0.8533±0.0000\n",
      "3 0.9015±0.0000\n",
      "4 0.9406±0.0000\n",
      "5 0.9339±0.0000\n",
      "6 0.9332±0.0000\n",
      "7 0.9339±0.0000\n",
      "8 0.9451±0.0000\n",
      "9 0.9307±0.0000\n",
      "10 0.9313±0.0000\n",
      "11 0.9442±0.0000\n",
      "12 0.9438±0.0000\n",
      "13 0.9435±0.0000\n",
      "14 0.9442±0.0000\n",
      "15 0.9570±0.0000\n",
      "16 0.9570±0.0000\n",
      "17 0.9580±0.0000\n",
      "18 0.9576±0.0000\n",
      "19 0.9448±0.0000\n",
      "20 0.9448±0.0000\n",
      "21 0.9461±0.0000\n",
      "22 0.9461±0.0000\n",
      "23 0.9461±0.0000\n",
      "24 0.9458±0.0000\n",
      "25 0.9326±0.0000\n",
      "26 0.9326±0.0000\n",
      "27 0.9326±0.0000\n",
      "28 0.9342±0.0000\n",
      "29 0.9342±0.0000\n",
      "30 0.9342±0.0000\n",
      "31 0.9342±0.0000\n",
      "32 0.9211±0.0000\n",
      "33 0.9211±0.0000\n",
      "34 0.9211±0.0000\n",
      "35 0.9211±0.0000\n",
      "36 0.9211±0.0000\n",
      "37 0.9211±0.0000\n",
      "38 0.9211±0.0000\n",
      "39 0.9079±0.0000\n",
      "40 0.9079±0.0000\n",
      "41 0.9079±0.0000\n",
      "42 0.9079±0.0000\n",
      "43 0.9079±0.0000\n",
      "44 0.9079±0.0000\n",
      "45 0.8947±0.0000\n",
      "46 0.8947±0.0000\n",
      "47 0.8947±0.0000\n",
      "48 0.8947±0.0000\n",
      "49 0.8947±0.0000\n",
      "50 0.8947±0.0000\n",
      "-------------------1----------------\n",
      "1 65.82±0.00\n",
      "2 73.42±0.00\n",
      "3 77.22±0.00\n",
      "4 87.34±0.00\n",
      "5 89.87±0.00\n",
      "6 92.41±0.00\n",
      "7 92.41±0.00\n",
      "8 93.67±0.00\n",
      "9 92.41±0.00\n",
      "10 93.67±0.00\n",
      "11 93.67±0.00\n",
      "12 92.41±0.00\n",
      "13 92.41±0.00\n",
      "14 93.67±0.00\n",
      "15 93.67±0.00\n",
      "16 93.67±0.00\n",
      "17 93.67±0.00\n",
      "18 93.67±0.00\n",
      "19 93.67±0.00\n",
      "20 94.94±0.00\n",
      "21 94.94±0.00\n",
      "22 94.94±0.00\n",
      "23 94.94±0.00\n",
      "24 94.94±0.00\n",
      "25 93.67±0.00\n",
      "26 93.67±0.00\n",
      "27 93.67±0.00\n",
      "28 93.67±0.00\n",
      "29 93.67±0.00\n",
      "30 93.67±0.00\n",
      "31 93.67±0.00\n",
      "32 92.41±0.00\n",
      "33 92.41±0.00\n",
      "34 92.41±0.00\n",
      "35 92.41±0.00\n",
      "36 92.41±0.00\n",
      "37 92.41±0.00\n",
      "38 92.41±0.00\n",
      "39 92.41±0.00\n",
      "40 91.14±0.00\n",
      "41 91.14±0.00\n",
      "42 91.14±0.00\n",
      "43 91.14±0.00\n",
      "44 91.14±0.00\n",
      "45 89.87±0.00\n",
      "46 89.87±0.00\n",
      "47 89.87±0.00\n",
      "48 89.87±0.00\n",
      "49 89.87±0.00\n",
      "50 89.87±0.00\n",
      "-------------------2----------------\n",
      "1 100.00±0.00\n",
      "2 97.37±0.00\n",
      "3 97.37±0.00\n",
      "4 97.37±0.00\n",
      "5 97.37±0.00\n",
      "6 94.74±0.00\n",
      "7 94.74±0.00\n",
      "8 94.74±0.00\n",
      "9 92.11±0.00\n",
      "10 92.11±0.00\n",
      "11 92.11±0.00\n",
      "12 89.47±0.00\n",
      "13 89.47±0.00\n",
      "14 89.47±0.00\n",
      "15 89.47±0.00\n",
      "16 89.47±0.00\n",
      "17 89.47±0.00\n",
      "18 89.47±0.00\n",
      "19 89.47±0.00\n",
      "20 89.47±0.00\n",
      "21 89.47±0.00\n",
      "22 89.47±0.00\n",
      "23 89.47±0.00\n",
      "24 89.47±0.00\n",
      "25 86.84±0.00\n",
      "26 86.84±0.00\n",
      "27 86.84±0.00\n",
      "28 86.84±0.00\n",
      "29 86.84±0.00\n",
      "30 86.84±0.00\n",
      "31 86.84±0.00\n",
      "32 84.21±0.00\n",
      "33 84.21±0.00\n",
      "34 84.21±0.00\n",
      "35 84.21±0.00\n",
      "36 84.21±0.00\n",
      "37 84.21±0.00\n",
      "38 84.21±0.00\n",
      "39 84.21±0.00\n",
      "40 81.58±0.00\n",
      "41 81.58±0.00\n",
      "42 81.58±0.00\n",
      "43 81.58±0.00\n",
      "44 81.58±0.00\n",
      "45 78.95±0.00\n",
      "46 78.95±0.00\n",
      "47 78.95±0.00\n",
      "48 78.95±0.00\n",
      "49 78.95±0.00\n",
      "50 78.95±0.00\n",
      "-------------------3----------------\n",
      "1 34.15±0.00\n",
      "2 51.22±0.00\n",
      "3 58.54±0.00\n",
      "4 78.05±0.00\n",
      "5 82.93±0.00\n",
      "6 90.24±0.00\n",
      "7 90.24±0.00\n",
      "8 92.68±0.00\n",
      "9 92.68±0.00\n",
      "10 95.12±0.00\n",
      "11 95.12±0.00\n",
      "12 95.12±0.00\n",
      "13 95.12±0.00\n",
      "14 97.56±0.00\n",
      "15 97.56±0.00\n",
      "16 97.56±0.00\n",
      "17 97.56±0.00\n",
      "18 97.56±0.00\n",
      "19 97.56±0.00\n",
      "20 100.00±0.00\n",
      "21 100.00±0.00\n",
      "22 100.00±0.00\n",
      "23 100.00±0.00\n",
      "24 100.00±0.00\n",
      "25 100.00±0.00\n",
      "26 100.00±0.00\n",
      "27 100.00±0.00\n",
      "28 100.00±0.00\n",
      "29 100.00±0.00\n",
      "30 100.00±0.00\n",
      "31 100.00±0.00\n",
      "32 100.00±0.00\n",
      "33 100.00±0.00\n",
      "34 100.00±0.00\n",
      "35 100.00±0.00\n",
      "36 100.00±0.00\n",
      "37 100.00±0.00\n",
      "38 100.00±0.00\n",
      "39 100.00±0.00\n",
      "40 100.00±0.00\n",
      "41 100.00±0.00\n",
      "42 100.00±0.00\n",
      "43 100.00±0.00\n",
      "44 100.00±0.00\n",
      "45 100.00±0.00\n",
      "46 100.00±0.00\n",
      "47 100.00±0.00\n",
      "48 100.00±0.00\n",
      "49 100.00±0.00\n",
      "50 100.00±0.00\n"
     ]
    }
   ],
   "source": [
    "exp_auc, exp_acc, exp_sen, exp_spe = np.array(exp_auc_list), np.array(exp_acc_list), np.array(exp_sen_list), np.array(exp_spe_list) \n",
    "for i, metr in enumerate([exp_auc, exp_acc, exp_sen, exp_spe]):\n",
    "    print(f'-------------------{i}----------------')\n",
    "    for n, (avg, std) in enumerate(zip(metr.mean(axis=0), metr.std(axis=0))):\n",
    "        if i == 0:\n",
    "            print(n+1, f'{avg:.4f}±{std:.4f}')\n",
    "        else:\n",
    "            print(n+1, f'{avg*100:.2f}±{std*100:.2f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09a5cd52",
   "metadata": {},
   "source": [
    "# 6.实验提高IMH recall的方法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2dcd5443",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import yaml\n",
    "from types import SimpleNamespace\n",
    "from collections import OrderedDict\n",
    "\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from PIL import Image\n",
    "from torch.utils.data import DataLoader, WeightedRandomSampler\n",
    "from torch import optim\n",
    "import torchvision.transforms as T\n",
    "from sklearn import metrics\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from data.datasets import AortaDataset3DCenter, get_weight_list\n",
    "from utils.ranger import Ranger\n",
    "from utils.lr_scheduler import CosineAnnealingWithWarmUpLR\n",
    "from model.resnet3d import resnet3d\n",
    "import data.transforms as MT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bccbba26",
   "metadata": {},
   "outputs": [],
   "source": [
    "vt_list = [\n",
    "    MT.Resize3D(81),\n",
    "    MT.CenterCrop3D(81),\n",
    "    MT.ToTensor3D(),\n",
    "    MT.SobelChannel(3, flag_3d=True)\n",
    "]\n",
    "val_transform = T.Compose(vt_list)\n",
    "val_dataset = AortaDataset3DCenter('/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/center/test', transform=val_transform, depth=7, step=1)\n",
    "val_loader = DataLoader(val_dataset,\n",
    "                        batch_size=128,\n",
    "                        shuffle=False,\n",
    "                        drop_last=False,\n",
    "                        num_workers=8, \n",
    "                        pin_memory=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1084516c",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda:3')\n",
    "load_model = '.details/checkpoints/MD3D/08-28_23:12:08_3.2.最优权重.正梯度/Net_best.pth'\n",
    "net = resnet3d(34, n_channels=2, n_classes=3, conv1_t_size=3)\n",
    "net.load_state_dict(torch.load(load_model, map_location=device))\n",
    "net.to(device)\n",
    "net.eval()\n",
    "print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26c4fdea",
   "metadata": {},
   "outputs": [],
   "source": [
    "tot_loss = 0\n",
    "true_list = []\n",
    "pred_list = []\n",
    "pred_ori_list = []\n",
    "with torch.no_grad():\n",
    "    with tqdm(total=len(val_dataset), desc=f'test round', unit='img', leave=False) as pbar:\n",
    "        for imgs, labels in val_loader:\n",
    "            imgs, labels = imgs.to(device), labels.to(device)\n",
    "            preds = net(imgs)\n",
    "            tot_loss += F.cross_entropy(preds, labels).item() * labels.size(0)\n",
    "            pred_idx = torch.softmax(preds, dim=1)\n",
    "            pred_idx += torch.tensor([-0.1, -0.1, 0.2], dtype=torch.float32, device=device)\n",
    "            pred_ori_list += pred_idx.tolist()\n",
    "            pred_idx = pred_idx.argmax(dim=1)\n",
    "            labels_list = labels.tolist()\n",
    "            true_list += labels_list\n",
    "            pred_idx = pred_idx.tolist()\n",
    "            pred_list.extend(pred_idx)\n",
    "            pbar.update(labels.size(0))\n",
    "\n",
    "AP = []\n",
    "for c in range(3):\n",
    "    c_true_list = [int(item==c) for item in true_list]\n",
    "    c_pred_ori_list = [item[c] for item in pred_ori_list]\n",
    "    AP.append(metrics.average_precision_score(c_true_list, c_pred_ori_list))\n",
    "\n",
    "print(f'report:\\n'+metrics.classification_report(true_list, pred_list, digits=4))\n",
    "\n",
    "print(float(np.mean(AP)))\n",
    "print(tot_loss/len(val_dataset))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9fb7fee9",
   "metadata": {},
   "source": [
    "# 7查看IMH分类错误的样本"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "760ab894",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import shutil\n",
    "import yaml\n",
    "from types import SimpleNamespace\n",
    "from collections import OrderedDict\n",
    "\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from PIL import Image\n",
    "from torchvision.datasets import ImageFolder\n",
    "from torch.utils.data import DataLoader, WeightedRandomSampler\n",
    "from torch import optim\n",
    "import torchvision.transforms as T\n",
    "from sklearn import metrics\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from data.datasets import get_weight_list\n",
    "from utils.ranger import Ranger\n",
    "from utils.lr_scheduler import CosineAnnealingWithWarmUpLR\n",
    "from model.SupCon import *\n",
    "import data.transforms as MT\n",
    "\n",
    "from torch.utils.data import Dataset\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3138a9ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TestDataset(Dataset):\n",
    "    def __init__(self, root, transform):\n",
    "        self.transform = transform\n",
    "        self.datas = []\n",
    "        for img in os.listdir(root):\n",
    "            self.datas.append(os.path.join(root, img))\n",
    "            \n",
    "    def __len__(self):\n",
    "        return len(self.datas)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        img = Image.open(self.datas[index])\n",
    "        img = self.transform(img)\n",
    "        return img, self.datas[index]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84667e95",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda:2')\n",
    "load_model = '.details/checkpoints/MD/09-26_03:06:49_最优权重正梯度,去掉var_loss/Net_best.pth'\n",
    "net = resnet(34, n_channels=2, n_classes=3)\n",
    "net.load_state_dict(torch.load(load_model, map_location=device))\n",
    "net.to(device)\n",
    "net.eval()\n",
    "print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de4ac7ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "image_size = 81\n",
    "vt_list = [\n",
    "    T.Resize(image_size),\n",
    "    T.CenterCrop(image_size),\n",
    "    T.ToTensor(),\n",
    "    MT.SobelChannel(3)\n",
    "]\n",
    "dataset = TestDataset('/nfs3-p2/zsxm/dataset/aorta_classify_ct_-100_500/test/2', T.Compose(vt_list))\n",
    "dataloader = DataLoader(dataset,\n",
    "                        batch_size=128,\n",
    "                        shuffle=False,\n",
    "                        drop_last=False,\n",
    "                        num_workers=8, \n",
    "                        pin_memory=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97854103",
   "metadata": {},
   "outputs": [],
   "source": [
    "pos_path = '/nfs3-p1/zsxm/dataset/zct_imh_error/zct_imh_test1/1'\n",
    "neg_path = '/nfs3-p1/zsxm/dataset/zct_imh_error/zct_imh_test1/0'\n",
    "os.makedirs(pos_path, exist_ok=False)\n",
    "os.makedirs(neg_path, exist_ok=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35ad9c63",
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    for imgs, paths in tqdm(dataloader):\n",
    "        imgs = imgs.to(device)\n",
    "        preds = torch.softmax(net(imgs), dim=1)\n",
    "        for pred, path in zip(preds, paths):\n",
    "            score = pred[2].item()\n",
    "            neg_score = pred[0].item()\n",
    "            ad_score = pred[1].item()\n",
    "            basename = os.path.basename(path)\n",
    "            if score > 0.5:\n",
    "                shutil.copy(path, os.path.join(pos_path, f'{score:.5f}_{neg_score:.5f}_{ad_score:.5f}_{basename}'))\n",
    "            else:\n",
    "                shutil.copy(path, os.path.join(neg_path, f'{score:.5f}_{neg_score:.5f}_{ad_score:.5f}_{basename}'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f7ef071",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "08185ecc",
   "metadata": {},
   "source": [
    "# 8.TSNE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56b8a731",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import yaml\n",
    "from types import SimpleNamespace\n",
    "from collections import OrderedDict\n",
    "\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from PIL import Image\n",
    "from torch.utils.data import DataLoader, WeightedRandomSampler, Dataset\n",
    "from torch import optim\n",
    "import torchvision.transforms as T\n",
    "from sklearn import metrics\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from data.datasets import AortaDataset3DCenter, get_weight_list, AortaDataset\n",
    "from utils.ranger import Ranger\n",
    "from utils.lr_scheduler import CosineAnnealingWithWarmUpLR\n",
    "from model.resnet3d import resnet3d\n",
    "from model.SupCon import resnet\n",
    "import data.transforms as MT\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from MulticoreTSNE import MulticoreTSNE as TSNE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "492801bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "#3D模型\n",
    "device = torch.device('cuda:2')\n",
    "load_model = '.details/checkpoints/MD3D/08-28_23:12:08_3.2.最优权重.正梯度/Net_best.pth'\n",
    "net = resnet3d(34, n_channels=2, n_classes=3, conv1_t_size=3, with_fea=True)\n",
    "net.load_state_dict(torch.load(load_model, map_location=device))\n",
    "net.to(device)\n",
    "net.eval()\n",
    "\n",
    "t_list = [\n",
    "    MT.Resize3D(81),\n",
    "    MT.CenterCrop3D(81),\n",
    "    MT.ToTensor3D(),\n",
    "    MT.SobelChannel(3, flag_3d=True)\n",
    "]\n",
    "transform = T.Compose(t_list)\n",
    "\n",
    "train_dataset = AortaDataset3DCenter('/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/center/train', [0,1,2], transform, depth=7, step=1)\n",
    "train_loader = DataLoader(train_dataset,\n",
    "                        batch_size=128,\n",
    "                        shuffle=False,\n",
    "                        drop_last=False,\n",
    "                        num_workers=8, \n",
    "                        pin_memory=True)\n",
    "\n",
    "val_dataset = AortaDataset3DCenter('/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/center/val', [0,1,2], transform, depth=7, step=1)\n",
    "val_loader = DataLoader(val_dataset,\n",
    "                        batch_size=128,\n",
    "                        shuffle=False,\n",
    "                        drop_last=False,\n",
    "                        num_workers=8, \n",
    "                        pin_memory=True)\n",
    "\n",
    "test_dataset = AortaDataset3DCenter('/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/center/test', [0,1,2], transform, depth=7, step=1)\n",
    "test_loader = DataLoader(test_dataset,\n",
    "                        batch_size=128,\n",
    "                        shuffle=False,\n",
    "                        drop_last=False,\n",
    "                        num_workers=8, \n",
    "                        pin_memory=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9fd9222",
   "metadata": {},
   "outputs": [],
   "source": [
    "#单张2D模型\n",
    "device = torch.device('cuda:2')\n",
    "n_classes = 3\n",
    "load_model = '.details/checkpoints/MD/09-26_03:06:49_最优权重正梯度,去掉var_loss/Net_best.pth'\n",
    "net = resnet(34, n_channels=2, n_classes=n_classes, with_fea=True)\n",
    "net.load_state_dict(torch.load(load_model, map_location=device))\n",
    "net.to(device)\n",
    "net.eval()\n",
    "\n",
    "t_list = [\n",
    "    T.Resize(81),\n",
    "    T.CenterCrop(81),\n",
    "    T.ToTensor(),\n",
    "    MT.SobelChannel(3)\n",
    "]\n",
    "transform = T.Compose(t_list)\n",
    "\n",
    "train_dataset = AortaDataset('/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/train', [0,1,2], transform=transform)\n",
    "train_loader = DataLoader(train_dataset,\n",
    "                        batch_size=128,\n",
    "                        shuffle=False,\n",
    "                        drop_last=False,\n",
    "                        num_workers=8, \n",
    "                        pin_memory=True)\n",
    "\n",
    "val_dataset = AortaDataset('/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/val', [0,1,2], transform=transform)\n",
    "val_loader = DataLoader(val_dataset,\n",
    "                        batch_size=128,\n",
    "                        shuffle=False,\n",
    "                        drop_last=False,\n",
    "                        num_workers=8, \n",
    "                        pin_memory=True)\n",
    "\n",
    "test_dataset = AortaDataset('/nfs3-p1/zsxm/dataset/aorta_classify_ct_-100_500/test', [0,1,2], transform=transform)\n",
    "test_loader = DataLoader(test_dataset,\n",
    "                        batch_size=128,\n",
    "                        shuffle=False,\n",
    "                        drop_last=False,\n",
    "                        num_workers=8, \n",
    "                        pin_memory=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d68c7f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "true_labels = []\n",
    "fea_list = []\n",
    "for imgs, labels in tqdm(train_loader):\n",
    "    imgs = imgs.to(device)\n",
    "    true_labels.extend(labels.tolist())\n",
    "    with torch.no_grad():\n",
    "        _, feas = net(imgs)\n",
    "    fea_list.append(feas.cpu())\n",
    "for imgs, labels in tqdm(val_loader):\n",
    "    imgs = imgs.to(device)\n",
    "    true_labels.extend((labels+n_classes).tolist())\n",
    "    with torch.no_grad():\n",
    "        _, feas = net(imgs)\n",
    "    fea_list.append(feas.cpu())\n",
    "for imgs, labels in tqdm(test_loader):\n",
    "    imgs = imgs.to(device)\n",
    "    true_labels.extend((labels+n_classes*2).tolist())\n",
    "    with torch.no_grad():\n",
    "        _, feas = net(imgs)\n",
    "    fea_list.append(feas.cpu())\n",
    "true_labels = np.array(true_labels)\n",
    "fea_list = torch.cat(fea_list, dim=0).numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e2ddec5",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(true_labels.shape)\n",
    "print(fea_list.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6350ad1",
   "metadata": {},
   "outputs": [],
   "source": [
    "feas_tsne = TSNE(n_jobs=8).fit_transform(fea_list)\n",
    "# vis_x = feas_tsne[:, 0]\n",
    "# vis_y = feas_tsne[:, 1]\n",
    "# plt.scatter(vis_x, vis_y, c=true_labels, cmap=plt.cm.get_cmap(\"jet\", n_classes*3), marker='.')\n",
    "# plt.colorbar(ticks=range(n_classes*3))\n",
    "# plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddff8212",
   "metadata": {},
   "outputs": [],
   "source": [
    "vis_x = feas_tsne[np.bitwise_or(np.bitwise_or(true_labels==0,true_labels==1),true_labels==2), 0]\n",
    "vis_y = feas_tsne[np.bitwise_or(np.bitwise_or(true_labels==0,true_labels==1),true_labels==2), 1]\n",
    "vis_labels = true_labels[np.bitwise_or(np.bitwise_or(true_labels==0,true_labels==1),true_labels==2)]\n",
    "ran_per = np.random.permutation(vis_labels.shape[0])\n",
    "vis_x = vis_x[ran_per]\n",
    "vis_y = vis_y[ran_per]\n",
    "vis_labels = vis_labels[ran_per]\n",
    "plt.scatter(vis_x, vis_y, c=vis_labels, cmap=plt.cm.get_cmap(\"jet\", n_classes), marker='.')\n",
    "plt.colorbar(ticks=range(n_classes))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d257b0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "vis_x = feas_tsne[np.bitwise_or(np.bitwise_or(true_labels==3,true_labels==4),true_labels==5), 0]\n",
    "vis_y = feas_tsne[np.bitwise_or(np.bitwise_or(true_labels==3,true_labels==4),true_labels==5), 1]\n",
    "vis_labels = true_labels[np.bitwise_or(np.bitwise_or(true_labels==3,true_labels==4),true_labels==5)]\n",
    "ran_per = np.random.permutation(vis_labels.shape[0])\n",
    "vis_x = vis_x[ran_per]\n",
    "vis_y = vis_y[ran_per]\n",
    "vis_labels = vis_labels[ran_per]-n_classes\n",
    "plt.scatter(vis_x, vis_y, c=vis_labels, cmap=plt.cm.get_cmap(\"jet\", n_classes), marker='.')\n",
    "plt.colorbar(ticks=range(n_classes))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29c43ecd",
   "metadata": {},
   "outputs": [],
   "source": [
    "vis_x = feas_tsne[np.bitwise_or(np.bitwise_or(true_labels==6,true_labels==7),true_labels==8), 0]\n",
    "vis_y = feas_tsne[np.bitwise_or(np.bitwise_or(true_labels==6,true_labels==7),true_labels==8), 1]\n",
    "vis_labels = true_labels[np.bitwise_or(np.bitwise_or(true_labels==6,true_labels==7),true_labels==8)]\n",
    "ran_per = np.random.permutation(vis_labels.shape[0])\n",
    "vis_x = vis_x[ran_per]\n",
    "vis_y = vis_y[ran_per]\n",
    "vis_labels = vis_labels[ran_per]-n_classes*2\n",
    "plt.scatter(vis_x, vis_y, c=vis_labels, cmap=plt.cm.get_cmap(\"jet\", n_classes), marker='.')\n",
    "plt.colorbar(ticks=range(n_classes))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "067c9f94",
   "metadata": {},
   "outputs": [],
   "source": [
    "vis_x = feas_tsne[np.bitwise_or(np.bitwise_or(true_labels==0,true_labels==1),true_labels==2), 0]\n",
    "vis_y = feas_tsne[np.bitwise_or(np.bitwise_or(true_labels==0,true_labels==1),true_labels==2), 1]\n",
    "vis_labels = true_labels[np.bitwise_or(np.bitwise_or(true_labels==0,true_labels==1),true_labels==2)]\n",
    "ran_per = np.random.permutation(vis_labels.shape[0])\n",
    "vis_x = vis_x[ran_per]\n",
    "vis_y = vis_y[ran_per]\n",
    "vis_labels = vis_labels[ran_per]\n",
    "plt.scatter(vis_x, vis_y, c=vis_labels, cmap=plt.cm.get_cmap(\"jet\", n_classes), marker='.')\n",
    "plt.colorbar(ticks=range(n_classes))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "935375ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "vis_x = feas_tsne[np.bitwise_or(np.bitwise_or(true_labels==3,true_labels==4),true_labels==5), 0]\n",
    "vis_y = feas_tsne[np.bitwise_or(np.bitwise_or(true_labels==3,true_labels==4),true_labels==5), 1]\n",
    "vis_labels = true_labels[np.bitwise_or(np.bitwise_or(true_labels==3,true_labels==4),true_labels==5)]\n",
    "ran_per = np.random.permutation(vis_labels.shape[0])\n",
    "vis_x = vis_x[ran_per]\n",
    "vis_y = vis_y[ran_per]\n",
    "vis_labels = vis_labels[ran_per]-n_classes\n",
    "plt.scatter(vis_x, vis_y, c=vis_labels, cmap=plt.cm.get_cmap(\"jet\", n_classes), marker='.')\n",
    "plt.colorbar(ticks=range(n_classes))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a04222c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "vis_x = feas_tsne[np.bitwise_or(np.bitwise_or(true_labels==6,true_labels==7),true_labels==8), 0]\n",
    "vis_y = feas_tsne[np.bitwise_or(np.bitwise_or(true_labels==6,true_labels==7),true_labels==8), 1]\n",
    "vis_labels = true_labels[np.bitwise_or(np.bitwise_or(true_labels==6,true_labels==7),true_labels==8)]\n",
    "ran_per = np.random.permutation(vis_labels.shape[0])\n",
    "vis_x = vis_x[ran_per]\n",
    "vis_y = vis_y[ran_per]\n",
    "vis_labels = vis_labels[ran_per]-n_classes*2\n",
    "plt.scatter(vis_x, vis_y, c=vis_labels, cmap=plt.cm.get_cmap(\"jet\", n_classes), marker='.')\n",
    "plt.colorbar(ticks=range(n_classes))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20dbf127",
   "metadata": {},
   "outputs": [],
   "source": [
    "vis_x = feas_tsne[np.bitwise_or(np.bitwise_or(true_labels==0,true_labels==1),true_labels==2), 0]\n",
    "vis_y = feas_tsne[np.bitwise_or(np.bitwise_or(true_labels==0,true_labels==1),true_labels==2), 1]\n",
    "vis_labels = true_labels[np.bitwise_or(np.bitwise_or(true_labels==0,true_labels==1),true_labels==2)]\n",
    "ran_per = np.random.permutation(vis_labels.shape[0])\n",
    "vis_x = vis_x[ran_per]\n",
    "vis_y = vis_y[ran_per]\n",
    "vis_labels = vis_labels[ran_per]\n",
    "plt.scatter(vis_x, vis_y, c=vis_labels, cmap=plt.cm.get_cmap(\"jet\", n_classes), marker='.')\n",
    "plt.colorbar(ticks=range(n_classes))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44d046bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "vis_x = feas_tsne[np.bitwise_or(np.bitwise_or(true_labels==3,true_labels==4),true_labels==5), 0]\n",
    "vis_y = feas_tsne[np.bitwise_or(np.bitwise_or(true_labels==3,true_labels==4),true_labels==5), 1]\n",
    "vis_labels = true_labels[np.bitwise_or(np.bitwise_or(true_labels==3,true_labels==4),true_labels==5)]\n",
    "ran_per = np.random.permutation(vis_labels.shape[0])\n",
    "vis_x = vis_x[ran_per]\n",
    "vis_y = vis_y[ran_per]\n",
    "vis_labels = vis_labels[ran_per]-n_classes\n",
    "plt.scatter(vis_x, vis_y, c=vis_labels, cmap=plt.cm.get_cmap(\"jet\", n_classes), marker='.')\n",
    "plt.colorbar(ticks=range(n_classes))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fc1e2e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "vis_x = feas_tsne[np.bitwise_or(np.bitwise_or(true_labels==6,true_labels==7),true_labels==8), 0]\n",
    "vis_y = feas_tsne[np.bitwise_or(np.bitwise_or(true_labels==6,true_labels==7),true_labels==8), 1]\n",
    "vis_labels = true_labels[np.bitwise_or(np.bitwise_or(true_labels==6,true_labels==7),true_labels==8)]\n",
    "ran_per = np.random.permutation(vis_labels.shape[0])\n",
    "vis_x = vis_x[ran_per]\n",
    "vis_y = vis_y[ran_per]\n",
    "vis_labels = vis_labels[ran_per]-n_classes*2\n",
    "plt.scatter(vis_x, vis_y, c=vis_labels, cmap=plt.cm.get_cmap(\"jet\", n_classes), marker='.')\n",
    "plt.colorbar(ticks=range(n_classes))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "463abfcb",
   "metadata": {},
   "source": [
    "# 9.Stanford分型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d36976a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports: stdlib, then third-party, then project-local modules.\n",
    "import copy\n",
    "import os\n",
    "import random\n",
    "import shutil\n",
    "import time\n",
    "import traceback\n",
    "from math import sqrt\n",
    "\n",
    "import cv2\n",
    "import numpy as np  # linear algebra\n",
    "import openpyxl\n",
    "import pydicom\n",
    "from pydicom.uid import UID\n",
    "from PIL import Image\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader, WeightedRandomSampler, Dataset\n",
    "import torchvision.transforms as T\n",
    "from sklearn import metrics\n",
    "\n",
    "from data.datasets import AortaDataset3DCenter\n",
    "from model.resnet3d import resnet3d\n",
    "import data.transforms as MT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "bfd55e66",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset: one sample per patient, pairing the patient's 'j_' and 's_' image stacks.\n",
    "class PatientDataset(Dataset):\n",
    "    def __init__(self, root, transform):\n",
    "        \"\"\"Scan root/<label>/<patient>/ folders; the class label (0 or 1) is the folder name.\"\"\"\n",
    "        self.datas = []\n",
    "        self.transform = transform\n",
    "        for i in ['0', '1']:\n",
    "            cpath = os.path.join(root, i)\n",
    "            i = int(i)\n",
    "            for patient in sorted(os.listdir(cpath)):\n",
    "                ppath = os.path.join(cpath, patient)\n",
    "                imgs = os.listdir(ppath)\n",
    "                pd = {}\n",
    "                pd['j'] = [os.path.join(ppath, x) for x in sorted(imgs) if x.startswith('j_')]\n",
    "                pd['s'] = [os.path.join(ppath, x) for x in sorted(imgs) if x.startswith('s_')]\n",
    "                # Fail early with a clear message instead of an opaque IndexError on the padding below.\n",
    "                assert pd['j'] and pd['s'], f'{ppath} has no j_/s_ images'\n",
    "                # Replicate 3 frames at each end — presumably so a centered 3D window exists for every slice; confirm against the model's input depth.\n",
    "                pd['j'] = [pd['j'][0]]*3 + pd['j'] + [pd['j'][-1]]*3\n",
    "                pd['s'] = [pd['s'][0]]*3 + pd['s'] + [pd['s'][-1]]*3\n",
    "                self.datas.append([patient, pd, i])\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.datas)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Return (j_volume, s_volume, label, patient_id); images are opened lazily per access.\"\"\"\n",
    "        patient, pd, label = self.datas[index]\n",
    "        jimgs = [Image.open(p) for p in pd['j']]\n",
    "        simgs = [Image.open(p) for p in pd['s']]\n",
    "        jimgs = self.transform(jimgs)\n",
    "        simgs = self.transform(simgs)\n",
    "        return jimgs, simgs, label, patient\n",
    "\n",
    "trans_list = [\n",
    "    MT.Resize3D(81),\n",
    "    MT.CenterCrop3D(81),\n",
    "    MT.ToTensor3D(),\n",
    "    MT.SobelChannel(3, flag_3d=True)\n",
    "]\n",
    "transform = T.Compose(trans_list)\n",
    "dataset = PatientDataset('/medical-data/zsxm/AorticDissection/adpaper/3scan/ct/test', transform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "6b094019",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "84\n",
      "laiguizhen-J-Im38-121 1\n",
      "wangweibo-J-Im36-149 1\n",
      "yuguoping-J-19-35 1\n"
     ]
    }
   ],
   "source": [
    "# Get each patient's Stanford classification (0 = negative, 1 = type A, 2 = type B).\n",
    "folderdict = {\n",
    "    '/medical-data/zsxm/AorticDissection/dataset/2021-07-30':list(),\n",
    "    '/medical-data/zsxm/AorticDissection/dataset/2021-09-08':list(),\n",
    "    '/medical-data/zsxm/AorticDissection/dataset/2021-09-13':list(),\n",
    "    '/medical-data/zsxm/AorticDissection/dataset/2021-09-17-negative':list(),\n",
    "    '/medical-data/zsxm/AorticDissection/dataset/2021-09-19':list(),\n",
    "    '/medical-data/zsxm/AorticDissection/dataset/2021-09-28':list(),\n",
    "    '/medical-data/zsxm/AorticDissection/dataset/2021-09-29-negative':list(),\n",
    "    '/medical-data/zsxm/AorticDissection/dataset/2021-11-20':list(), \n",
    "}\n",
    "for k, v in folderdict.items():\n",
    "    v.extend(sorted([f for f in os.listdir(k) if os.path.isdir(os.path.join(k,f))]))\n",
    "\n",
    "patient_set = set()\n",
    "for patient, _, _ in dataset.datas:\n",
    "    patient_set.add(patient)\n",
    "patient_list = sorted(list(patient_set))\n",
    "\n",
    "stanford = {}\n",
    "for patient in patient_list:\n",
    "    # Each patient must come from exactly one source folder.\n",
    "    folder_list = [k for k, v in folderdict.items() if patient in v]\n",
    "    assert len(folder_list) == 1, f'{patient}:{folder_list}'\n",
    "    folder = folder_list[0]\n",
    "    if 'negative' in folder:\n",
    "        stanford[patient] = 0\n",
    "    else:\n",
    "        wb = openpyxl.load_workbook(os.path.join(folder, 'label.xlsx'))\n",
    "        try:\n",
    "            st = wb['Sheet1']\n",
    "            for row in st.iter_rows():\n",
    "                if row[0].value is not None and row[0].value == patient.split('-')[0]:\n",
    "                    # Guard against an empty cell: (None or '') keeps us on the informative ValueError path\n",
    "                    # instead of crashing with AttributeError on .lower().\n",
    "                    stan = (row[5].value or '').lower()\n",
    "                    if 'a' in stan:\n",
    "                        stanford[patient] = 1\n",
    "                    elif 'b' in stan:\n",
    "                        stanford[patient] = 2\n",
    "                    else:\n",
    "                        raise ValueError(f'{patient}: {stan}')\n",
    "                    break\n",
    "            else:\n",
    "                raise ValueError(f\"Can't find {patient} in {os.path.join(folder, 'label.xlsx')}\")\n",
    "        finally:\n",
    "            wb.close()  # release the xlsx file handle even when a ValueError is raised\n",
    "print(len(stanford))\n",
    "# NOTE(review): patients labeled type A (1) whose id lacks 's' are downgraded to type B (2) —\n",
    "# presumably the 's' naming marks the scans compatible with a type-A label; confirm the convention.\n",
    "for k, v in stanford.items():\n",
    "    if v == 1 and 's' not in k.lower():\n",
    "        print(k,v)\n",
    "        stanford[k] = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba87b26c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "82d353c5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1/1: 100%|████████| 84/84 [00:23<00:00,  3.56it/s]\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     1.0000    0.2826    0.4407        46\n",
      "           1     0.2051    0.5714    0.3019        14\n",
      "           2     0.2500    0.3333    0.2857        24\n",
      "\n",
      "    accuracy                         0.3452        84\n",
      "   macro avg     0.4850    0.3958    0.3428        84\n",
      "weighted avg     0.6532    0.3452    0.3733        84\n",
      "\n",
      "2 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     1.0000    0.4348    0.6061        46\n",
      "           1     0.2667    0.5714    0.3636        14\n",
      "           2     0.3529    0.5000    0.4138        24\n",
      "\n",
      "    accuracy                         0.4762        84\n",
      "   macro avg     0.5399    0.5021    0.4612        84\n",
      "weighted avg     0.6929    0.4762    0.5107        84\n",
      "\n",
      "3 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     1.0000    0.4565    0.6269        46\n",
      "           1     0.2593    0.5000    0.3415        14\n",
      "           2     0.3333    0.5000    0.4000        24\n",
      "\n",
      "    accuracy                         0.4762        84\n",
      "   macro avg     0.5309    0.4855    0.4561        84\n",
      "weighted avg     0.6861    0.4762    0.5145        84\n",
      "\n",
      "4 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.9667    0.6304    0.7632        46\n",
      "           1     0.3333    0.5000    0.4000        14\n",
      "           2     0.4242    0.5833    0.4912        24\n",
      "\n",
      "    accuracy                         0.5952        84\n",
      "   macro avg     0.5747    0.5713    0.5515        84\n",
      "weighted avg     0.7061    0.5952    0.6249        84\n",
      "\n",
      "5 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.9706    0.7174    0.8250        46\n",
      "           1     0.3333    0.4286    0.3750        14\n",
      "           2     0.4375    0.5833    0.5000        24\n",
      "\n",
      "    accuracy                         0.6310        84\n",
      "   macro avg     0.5805    0.5764    0.5667        84\n",
      "weighted avg     0.7121    0.6310    0.6571        84\n",
      "\n",
      "6 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.9500    0.8261    0.8837        46\n",
      "           1     0.3333    0.3571    0.3448        14\n",
      "           2     0.5172    0.6250    0.5660        24\n",
      "\n",
      "    accuracy                         0.6905        84\n",
      "   macro avg     0.6002    0.6027    0.5982        84\n",
      "weighted avg     0.7236    0.6905    0.7031        84\n",
      "\n",
      "7 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.9524    0.8696    0.9091        46\n",
      "           1     0.4167    0.3571    0.3846        14\n",
      "           2     0.5667    0.7083    0.6296        24\n",
      "\n",
      "    accuracy                         0.7381        84\n",
      "   macro avg     0.6452    0.6450    0.6411        84\n",
      "weighted avg     0.7529    0.7381    0.7418        84\n",
      "\n",
      "8 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.9524    0.8696    0.9091        46\n",
      "           1     0.4444    0.2857    0.3478        14\n",
      "           2     0.5455    0.7500    0.6316        24\n",
      "\n",
      "    accuracy                         0.7381        84\n",
      "   macro avg     0.6474    0.6351    0.6295        84\n",
      "weighted avg     0.7515    0.7381    0.7363        84\n",
      "\n",
      "9 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.9524    0.8696    0.9091        46\n",
      "           1     0.5000    0.2857    0.3636        14\n",
      "           2     0.5294    0.7500    0.6207        24\n",
      "\n",
      "    accuracy                         0.7381        84\n",
      "   macro avg     0.6606    0.6351    0.6311        84\n",
      "weighted avg     0.7561    0.7381    0.7358        84\n",
      "\n",
      "10 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.9302    0.8696    0.8989        46\n",
      "           1     0.5000    0.2857    0.3636        14\n",
      "           2     0.5152    0.7083    0.5965        24\n",
      "\n",
      "    accuracy                         0.7262        84\n",
      "   macro avg     0.6485    0.6212    0.6197        84\n",
      "weighted avg     0.7399    0.7262    0.7233        84\n",
      "\n",
      "11 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.9091    0.8696    0.8889        46\n",
      "           1     0.5000    0.2857    0.3636        14\n",
      "           2     0.5000    0.6667    0.5714        24\n",
      "\n",
      "    accuracy                         0.7143        84\n",
      "   macro avg     0.6364    0.6073    0.6080        84\n",
      "weighted avg     0.7240    0.7143    0.7106        84\n",
      "\n",
      "12 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8889    0.8696    0.8791        46\n",
      "           1     0.4286    0.2143    0.2857        14\n",
      "           2     0.5000    0.6667    0.5714        24\n",
      "\n",
      "    accuracy                         0.7024        84\n",
      "   macro avg     0.6058    0.5835    0.5788        84\n",
      "weighted avg     0.7011    0.7024    0.6923        84\n",
      "\n",
      "13 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8936    0.9130    0.9032        46\n",
      "           1     0.6000    0.2143    0.3158        14\n",
      "           2     0.5625    0.7500    0.6429        24\n",
      "\n",
      "    accuracy                         0.7500        84\n",
      "   macro avg     0.6854    0.6258    0.6206        84\n",
      "weighted avg     0.7501    0.7500    0.7309        84\n",
      "\n",
      "14 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8958    0.9348    0.9149        46\n",
      "           1     0.6000    0.2143    0.3158        14\n",
      "           2     0.5806    0.7500    0.6545        24\n",
      "\n",
      "    accuracy                         0.7619        84\n",
      "   macro avg     0.6922    0.6330    0.6284        84\n",
      "weighted avg     0.7565    0.7619    0.7407        84\n",
      "\n",
      "15 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8776    0.9348    0.9053        46\n",
      "           1     0.6667    0.1429    0.2353        14\n",
      "           2     0.5938    0.7917    0.6786        24\n",
      "\n",
      "    accuracy                         0.7619        84\n",
      "   macro avg     0.7127    0.6231    0.6064        84\n",
      "weighted avg     0.7613    0.7619    0.7288        84\n",
      "\n",
      "16 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8431    0.9348    0.8866        46\n",
      "           1     0.6667    0.1429    0.2353        14\n",
      "           2     0.6000    0.7500    0.6667        24\n",
      "\n",
      "    accuracy                         0.7500        84\n",
      "   macro avg     0.7033    0.6092    0.5962        84\n",
      "weighted avg     0.7443    0.7500    0.7152        84\n",
      "\n",
      "17 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8431    0.9348    0.8866        46\n",
      "           1     0.6667    0.1429    0.2353        14\n",
      "           2     0.6000    0.7500    0.6667        24\n",
      "\n",
      "    accuracy                         0.7500        84\n",
      "   macro avg     0.7033    0.6092    0.5962        84\n",
      "weighted avg     0.7443    0.7500    0.7152        84\n",
      "\n",
      "18 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8431    0.9348    0.8866        46\n",
      "           1     1.0000    0.1429    0.2500        14\n",
      "           2     0.6129    0.7917    0.6909        24\n",
      "\n",
      "    accuracy                         0.7619        84\n",
      "   macro avg     0.8187    0.6231    0.6092        84\n",
      "weighted avg     0.8035    0.7619    0.7246        84\n",
      "\n",
      "19 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8431    0.9348    0.8866        46\n",
      "           1     1.0000    0.0714    0.1333        14\n",
      "           2     0.5938    0.7917    0.6786        24\n",
      "\n",
      "    accuracy                         0.7500        84\n",
      "   macro avg     0.8123    0.5993    0.5662        84\n",
      "weighted avg     0.7980    0.7500    0.7016        84\n",
      "\n",
      "20 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8431    0.9348    0.8866        46\n",
      "           1     1.0000    0.0714    0.1333        14\n",
      "           2     0.5938    0.7917    0.6786        24\n",
      "\n",
      "    accuracy                         0.7500        84\n",
      "   macro avg     0.8123    0.5993    0.5662        84\n",
      "weighted avg     0.7980    0.7500    0.7016        84\n",
      "\n",
      "21 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8431    0.9348    0.8866        46\n",
      "           1     1.0000    0.0714    0.1333        14\n",
      "           2     0.5938    0.7917    0.6786        24\n",
      "\n",
      "    accuracy                         0.7500        84\n",
      "   macro avg     0.8123    0.5993    0.5662        84\n",
      "weighted avg     0.7980    0.7500    0.7016        84\n",
      "\n",
      "22 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8431    0.9348    0.8866        46\n",
      "           1     1.0000    0.0714    0.1333        14\n",
      "           2     0.5938    0.7917    0.6786        24\n",
      "\n",
      "    accuracy                         0.7500        84\n",
      "   macro avg     0.8123    0.5993    0.5662        84\n",
      "weighted avg     0.7980    0.7500    0.7016        84\n",
      "\n",
      "23 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8431    0.9348    0.8866        46\n",
      "           1     1.0000    0.0714    0.1333        14\n",
      "           2     0.5938    0.7917    0.6786        24\n",
      "\n",
      "    accuracy                         0.7500        84\n",
      "   macro avg     0.8123    0.5993    0.5662        84\n",
      "weighted avg     0.7980    0.7500    0.7016        84\n",
      "\n",
      "24 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8462    0.9565    0.8980        46\n",
      "           1     1.0000    0.0714    0.1333        14\n",
      "           2     0.6129    0.7917    0.6909        24\n",
      "\n",
      "    accuracy                         0.7619        84\n",
      "   macro avg     0.8197    0.6065    0.5741        84\n",
      "weighted avg     0.8052    0.7619    0.7114        84\n",
      "\n",
      "25 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8462    0.9565    0.8980        46\n",
      "           1     1.0000    0.0714    0.1333        14\n",
      "           2     0.6129    0.7917    0.6909        24\n",
      "\n",
      "    accuracy                         0.7619        84\n",
      "   macro avg     0.8197    0.6065    0.5741        84\n",
      "weighted avg     0.8052    0.7619    0.7114        84\n",
      "\n",
      "26 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8462    0.9565    0.8980        46\n",
      "           1     1.0000    0.0714    0.1333        14\n",
      "           2     0.6129    0.7917    0.6909        24\n",
      "\n",
      "    accuracy                         0.7619        84\n",
      "   macro avg     0.8197    0.6065    0.5741        84\n",
      "weighted avg     0.8052    0.7619    0.7114        84\n",
      "\n",
      "27 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8462    0.9565    0.8980        46\n",
      "           1     1.0000    0.0714    0.1333        14\n",
      "           2     0.6129    0.7917    0.6909        24\n",
      "\n",
      "    accuracy                         0.7619        84\n",
      "   macro avg     0.8197    0.6065    0.5741        84\n",
      "weighted avg     0.8052    0.7619    0.7114        84\n",
      "\n",
      "28 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8302    0.9565    0.8889        46\n",
      "           1     1.0000    0.0714    0.1333        14\n",
      "           2     0.6333    0.7917    0.7037        24\n",
      "\n",
      "    accuracy                         0.7619        84\n",
      "   macro avg     0.8212    0.6065    0.5753        84\n",
      "weighted avg     0.8022    0.7619    0.7101        84\n",
      "\n",
      "29 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8182    0.9783    0.8911        46\n",
      "           1     1.0000    0.0714    0.1333        14\n",
      "           2     0.6429    0.7500    0.6923        24\n",
      "\n",
      "    accuracy                         0.7619        84\n",
      "   macro avg     0.8203    0.5999    0.5722        84\n",
      "weighted avg     0.7984    0.7619    0.7080        84\n",
      "\n",
      "30 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8182    0.9783    0.8911        46\n",
      "           1     0.0000    0.0000    0.0000        14\n",
      "           2     0.6207    0.7500    0.6792        24\n",
      "\n",
      "    accuracy                         0.7500        84\n",
      "   macro avg     0.4796    0.5761    0.5234        84\n",
      "weighted avg     0.6254    0.7500    0.6820        84\n",
      "\n",
      "31 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8182    0.9783    0.8911        46\n",
      "           1     0.0000    0.0000    0.0000        14\n",
      "           2     0.6207    0.7500    0.6792        24\n",
      "\n",
      "    accuracy                         0.7500        84\n",
      "   macro avg     0.4796    0.5761    0.5234        84\n",
      "weighted avg     0.6254    0.7500    0.6820        84\n",
      "\n",
      "32 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8070    1.0000    0.8932        46\n",
      "           1     0.0000    0.0000    0.0000        14\n",
      "           2     0.6667    0.7500    0.7059        24\n",
      "\n",
      "    accuracy                         0.7619        84\n",
      "   macro avg     0.4912    0.5833    0.5330        84\n",
      "weighted avg     0.6324    0.7619    0.6908        84\n",
      "\n",
      "33 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8070    1.0000    0.8932        46\n",
      "           1     0.0000    0.0000    0.0000        14\n",
      "           2     0.6667    0.7500    0.7059        24\n",
      "\n",
      "    accuracy                         0.7619        84\n",
      "   macro avg     0.4912    0.5833    0.5330        84\n",
      "weighted avg     0.6324    0.7619    0.6908        84\n",
      "\n",
      "34 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8070    1.0000    0.8932        46\n",
      "           1     0.0000    0.0000    0.0000        14\n",
      "           2     0.6667    0.7500    0.7059        24\n",
      "\n",
      "    accuracy                         0.7619        84\n",
      "   macro avg     0.4912    0.5833    0.5330        84\n",
      "weighted avg     0.6324    0.7619    0.6908        84\n",
      "\n",
      "35 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.7931    1.0000    0.8846        46\n",
      "           1     0.0000    0.0000    0.0000        14\n",
      "           2     0.6923    0.7500    0.7200        24\n",
      "\n",
      "    accuracy                         0.7619        84\n",
      "   macro avg     0.4951    0.5833    0.5349        84\n",
      "weighted avg     0.6321    0.7619    0.6901        84\n",
      "\n",
      "36 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.7931    1.0000    0.8846        46\n",
      "           1     0.0000    0.0000    0.0000        14\n",
      "           2     0.6923    0.7500    0.7200        24\n",
      "\n",
      "    accuracy                         0.7619        84\n",
      "   macro avg     0.4951    0.5833    0.5349        84\n",
      "weighted avg     0.6321    0.7619    0.6901        84\n",
      "\n",
      "37 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.7931    1.0000    0.8846        46\n",
      "           1     0.0000    0.0000    0.0000        14\n",
      "           2     0.6923    0.7500    0.7200        24\n",
      "\n",
      "    accuracy                         0.7619        84\n",
      "   macro avg     0.4951    0.5833    0.5349        84\n",
      "weighted avg     0.6321    0.7619    0.6901        84\n",
      "\n",
      "38 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.7931    1.0000    0.8846        46\n",
      "           1     0.0000    0.0000    0.0000        14\n",
      "           2     0.6923    0.7500    0.7200        24\n",
      "\n",
      "    accuracy                         0.7619        84\n",
      "   macro avg     0.4951    0.5833    0.5349        84\n",
      "weighted avg     0.6321    0.7619    0.6901        84\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "39 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.7931    1.0000    0.8846        46\n",
      "           1     0.0000    0.0000    0.0000        14\n",
      "           2     0.6923    0.7500    0.7200        24\n",
      "\n",
      "    accuracy                         0.7619        84\n",
      "   macro avg     0.4951    0.5833    0.5349        84\n",
      "weighted avg     0.6321    0.7619    0.6901        84\n",
      "\n",
      "40 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.7931    1.0000    0.8846        46\n",
      "           1     0.0000    0.0000    0.0000        14\n",
      "           2     0.6923    0.7500    0.7200        24\n",
      "\n",
      "    accuracy                         0.7619        84\n",
      "   macro avg     0.4951    0.5833    0.5349        84\n",
      "weighted avg     0.6321    0.7619    0.6901        84\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n"
     ]
    }
   ],
   "source": [
    "#测试Stanford分型结果\n",
    "weight_list = [\n",
    "    '.details151/CE3D/09-12_174850_3D,有Sobel,只分类阴性和夹层/Net_best.pth',\n",
    "#     '.details151/CE3D/09-12_180329_3D,有Sobel,只分类阴性和夹层/Net_best.pth',\n",
    "#     '.details151/CE3D/09-12_180348_3D,有Sobel,只分类阴性和夹层/Net_best.pth',\n",
    "#     '.details151/CE3D/12-05_183652_3D,有Sobel,只分类阴性和夹层/Net_best.pth',\n",
    "#     '.details151/CE3D/12-05_185011_3D,有Sobel,只分类阴性和夹层/Net_best.pth',\n",
    "]\n",
    "N_MAX = 40\n",
    "device = torch.device('cuda:2')\n",
    "exp_auc_list, exp_acc_list, exp_sen_list, exp_spe_list = [], [], [], []\n",
    "for iw, weight in enumerate(weight_list):\n",
    "    net = resnet3d(34, n_channels=2, n_classes=2, conv1_t_size=3)\n",
    "    net.load_state_dict(torch.load(weight, map_location=device))\n",
    "    net.to(device)\n",
    "    net.eval()\n",
    "    \n",
    "    true_list = []\n",
    "    sjpred_list = []\n",
    "    with torch.no_grad():\n",
    "        for jimgs, simgs, label, patient in tqdm(dataset, desc=f'{iw+1}/{len(weight_list)}', leave=True, ncols=50):\n",
    "            true_list.append(stanford[patient])\n",
    "            simgs = [torch.stack(simgs[i:i+7], dim=1) for i in range(len(simgs)-6)]\n",
    "            jimgs = [torch.stack(jimgs[i:i+7], dim=1) for i in range(len(jimgs)-6)]\n",
    "            imgs = torch.stack(simgs+jimgs, dim=0).to(device)\n",
    "            preds = torch.argmax(net(imgs),dim=1).tolist()\n",
    "            spreds, jpreds = preds[:len(simgs)], preds[len(simgs):]\n",
    "            sjpred_list.append( (spreds, jpreds) )\n",
    "    true_list = np.array(true_list)\n",
    "    for N in range(1, N_MAX+1):\n",
    "        pred_list = []\n",
    "        for spreds, jpreds in sjpred_list:\n",
    "            patient_label = 0\n",
    "            for isj, preds in enumerate([spreds, jpreds]):\n",
    "                pred_label = 0\n",
    "                pre = -1\n",
    "                max_len = -1\n",
    "                start_idx = -1\n",
    "                for i in range(len(preds)):\n",
    "                    cur = preds[i]\n",
    "                    if cur != pre:\n",
    "                        if pre in [1, 2]:\n",
    "                            ln = i - start_idx\n",
    "                            if ln >= N and ln > max_len:\n",
    "                                max_len = ln\n",
    "                                pred_label = pre\n",
    "                        start_idx = i\n",
    "                    pre = cur\n",
    "                if pred_label != 0:\n",
    "                    if isj == 0:\n",
    "                        patient_label = 1\n",
    "                        break\n",
    "                    else:\n",
    "                        patient_label = 2\n",
    "            pred_list.append(patient_label)\n",
    "        pred_list = np.array(pred_list)\n",
    "#         true_list[true_list == 2] =1\n",
    "#         pred_list[pred_list == 2] =1\n",
    "        print(f'{N} stanford report:\\n'+metrics.classification_report(true_list, pred_list, digits=4))\n",
    "    del net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "3e4e4df8",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1/1: 100%|████████| 84/84 [00:29<00:00,  2.83it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     1.0000    0.2826    0.4407        46\n",
      "           1     0.1538    0.5455    0.2400        11\n",
      "           2     0.2812    0.3333    0.3051        27\n",
      "\n",
      "    accuracy                         0.3333        84\n",
      "   macro avg     0.4784    0.3871    0.3286        84\n",
      "weighted avg     0.6582    0.3333    0.3708        84\n",
      "\n",
      "2 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     1.0000    0.4348    0.6061        46\n",
      "           1     0.2000    0.5455    0.2927        11\n",
      "           2     0.3824    0.4815    0.4262        27\n",
      "\n",
      "    accuracy                         0.4643        84\n",
      "   macro avg     0.5275    0.4872    0.4417        84\n",
      "weighted avg     0.6967    0.4643    0.5072        84\n",
      "\n",
      "3 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     1.0000    0.4565    0.6269        46\n",
      "           1     0.1852    0.4545    0.2632        11\n",
      "           2     0.3611    0.4815    0.4127        27\n",
      "\n",
      "    accuracy                         0.4643        84\n",
      "   macro avg     0.5154    0.4642    0.4342        84\n",
      "weighted avg     0.6879    0.4643    0.5104        84\n",
      "\n",
      "4 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.9667    0.6304    0.7632        46\n",
      "           1     0.2381    0.4545    0.3125        11\n",
      "           2     0.4545    0.5556    0.5000        27\n",
      "\n",
      "    accuracy                         0.5833        84\n",
      "   macro avg     0.5531    0.5468    0.5252        84\n",
      "weighted avg     0.7066    0.5833    0.6196        84\n",
      "\n",
      "5 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.9706    0.7174    0.8250        46\n",
      "           1     0.2778    0.4545    0.3448        11\n",
      "           2     0.5000    0.5926    0.5424        27\n",
      "\n",
      "    accuracy                         0.6429        84\n",
      "   macro avg     0.5828    0.5882    0.5707        84\n",
      "weighted avg     0.7286    0.6429    0.6713        84\n",
      "\n",
      "6 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.9500    0.8261    0.8837        46\n",
      "           1     0.2667    0.3636    0.3077        11\n",
      "           2     0.5862    0.6296    0.6071        27\n",
      "\n",
      "    accuracy                         0.7024        84\n",
      "   macro avg     0.6010    0.6065    0.5995        84\n",
      "weighted avg     0.7436    0.7024    0.7194        84\n",
      "\n",
      "7 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.9524    0.8696    0.9091        46\n",
      "           1     0.3333    0.3636    0.3478        11\n",
      "           2     0.6333    0.7037    0.6667        27\n",
      "\n",
      "    accuracy                         0.7500        84\n",
      "   macro avg     0.6397    0.6456    0.6412        84\n",
      "weighted avg     0.7688    0.7500    0.7577        84\n",
      "\n",
      "8 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.9524    0.8696    0.9091        46\n",
      "           1     0.3333    0.2727    0.3000        11\n",
      "           2     0.6061    0.7407    0.6667        27\n",
      "\n",
      "    accuracy                         0.7500        84\n",
      "   macro avg     0.6306    0.6277    0.6253        84\n",
      "weighted avg     0.7600    0.7500    0.7514        84\n",
      "\n",
      "9 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.9524    0.8696    0.9091        46\n",
      "           1     0.3750    0.2727    0.3158        11\n",
      "           2     0.5882    0.7407    0.6557        27\n",
      "\n",
      "    accuracy                         0.7500        84\n",
      "   macro avg     0.6385    0.6277    0.6269        84\n",
      "weighted avg     0.7597    0.7500    0.7500        84\n",
      "\n",
      "10 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.9302    0.8696    0.8989        46\n",
      "           1     0.3750    0.2727    0.3158        11\n",
      "           2     0.5758    0.7037    0.6333        27\n",
      "\n",
      "    accuracy                         0.7381        84\n",
      "   macro avg     0.6270    0.6153    0.6160        84\n",
      "weighted avg     0.7436    0.7381    0.7372        84\n",
      "\n",
      "11 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.9091    0.8696    0.8889        46\n",
      "           1     0.3750    0.2727    0.3158        11\n",
      "           2     0.5625    0.6667    0.6102        27\n",
      "\n",
      "    accuracy                         0.7262        84\n",
      "   macro avg     0.6155    0.6030    0.6049        84\n",
      "weighted avg     0.7277    0.7262    0.7243        84\n",
      "\n",
      "12 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8889    0.8696    0.8791        46\n",
      "           1     0.4286    0.2727    0.3333        11\n",
      "           2     0.5625    0.6667    0.6102        27\n",
      "\n",
      "    accuracy                         0.7262        84\n",
      "   macro avg     0.6267    0.6030    0.6075        84\n",
      "weighted avg     0.7237    0.7262    0.7212        84\n",
      "\n",
      "13 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8936    0.9130    0.9032        46\n",
      "           1     0.6000    0.2727    0.3750        11\n",
      "           2     0.6250    0.7407    0.6780        27\n",
      "\n",
      "    accuracy                         0.7738        84\n",
      "   macro avg     0.7062    0.6422    0.6521        84\n",
      "weighted avg     0.7688    0.7738    0.7616        84\n",
      "\n",
      "14 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8958    0.9348    0.9149        46\n",
      "           1     0.6000    0.2727    0.3750        11\n",
      "           2     0.6452    0.7407    0.6897        27\n",
      "\n",
      "    accuracy                         0.7857        84\n",
      "   macro avg     0.7137    0.6494    0.6598        84\n",
      "weighted avg     0.7765    0.7857    0.7718        84\n",
      "\n",
      "15 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8776    0.9348    0.9053        46\n",
      "           1     0.6667    0.1818    0.2857        11\n",
      "           2     0.6562    0.7778    0.7119        27\n",
      "\n",
      "    accuracy                         0.7857        84\n",
      "   macro avg     0.7335    0.6315    0.6343        84\n",
      "weighted avg     0.7788    0.7857    0.7620        84\n",
      "\n",
      "16 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8431    0.9348    0.8866        46\n",
      "           1     0.6667    0.1818    0.2857        11\n",
      "           2     0.6667    0.7407    0.7018        27\n",
      "\n",
      "    accuracy                         0.7738        84\n",
      "   macro avg     0.7255    0.6191    0.6247        84\n",
      "weighted avg     0.7633    0.7738    0.7485        84\n",
      "\n",
      "17 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8431    0.9348    0.8866        46\n",
      "           1     0.6667    0.1818    0.2857        11\n",
      "           2     0.6667    0.7407    0.7018        27\n",
      "\n",
      "    accuracy                         0.7738        84\n",
      "   macro avg     0.7255    0.6191    0.6247        84\n",
      "weighted avg     0.7633    0.7738    0.7485        84\n",
      "\n",
      "18 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8431    0.9348    0.8866        46\n",
      "           1     1.0000    0.1818    0.3077        11\n",
      "           2     0.6774    0.7778    0.7241        27\n",
      "\n",
      "    accuracy                         0.7857        84\n",
      "   macro avg     0.8402    0.6315    0.6395        84\n",
      "weighted avg     0.8104    0.7857    0.7586        84\n",
      "\n",
      "19 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8431    0.9348    0.8866        46\n",
      "           1     1.0000    0.0909    0.1667        11\n",
      "           2     0.6562    0.7778    0.7119        27\n",
      "\n",
      "    accuracy                         0.7738        84\n",
      "   macro avg     0.8331    0.6012    0.5884        84\n",
      "weighted avg     0.8036    0.7738    0.7362        84\n",
      "\n",
      "20 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8431    0.9348    0.8866        46\n",
      "           1     1.0000    0.0909    0.1667        11\n",
      "           2     0.6562    0.7778    0.7119        27\n",
      "\n",
      "    accuracy                         0.7738        84\n",
      "   macro avg     0.8331    0.6012    0.5884        84\n",
      "weighted avg     0.8036    0.7738    0.7362        84\n",
      "\n",
      "21 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8431    0.9348    0.8866        46\n",
      "           1     1.0000    0.0909    0.1667        11\n",
      "           2     0.6562    0.7778    0.7119        27\n",
      "\n",
      "    accuracy                         0.7738        84\n",
      "   macro avg     0.8331    0.6012    0.5884        84\n",
      "weighted avg     0.8036    0.7738    0.7362        84\n",
      "\n",
      "22 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8431    0.9348    0.8866        46\n",
      "           1     1.0000    0.0909    0.1667        11\n",
      "           2     0.6562    0.7778    0.7119        27\n",
      "\n",
      "    accuracy                         0.7738        84\n",
      "   macro avg     0.8331    0.6012    0.5884        84\n",
      "weighted avg     0.8036    0.7738    0.7362        84\n",
      "\n",
      "23 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8431    0.9348    0.8866        46\n",
      "           1     1.0000    0.0909    0.1667        11\n",
      "           2     0.6562    0.7778    0.7119        27\n",
      "\n",
      "    accuracy                         0.7738        84\n",
      "   macro avg     0.8331    0.6012    0.5884        84\n",
      "weighted avg     0.8036    0.7738    0.7362        84\n",
      "\n",
      "24 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8462    0.9565    0.8980        46\n",
      "           1     1.0000    0.0909    0.1667        11\n",
      "           2     0.6774    0.7778    0.7241        27\n",
      "\n",
      "    accuracy                         0.7857        84\n",
      "   macro avg     0.8412    0.6084    0.5963        84\n",
      "weighted avg     0.8121    0.7857    0.7463        84\n",
      "\n",
      "25 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8462    0.9565    0.8980        46\n",
      "           1     1.0000    0.0909    0.1667        11\n",
      "           2     0.6774    0.7778    0.7241        27\n",
      "\n",
      "    accuracy                         0.7857        84\n",
      "   macro avg     0.8412    0.6084    0.5963        84\n",
      "weighted avg     0.8121    0.7857    0.7463        84\n",
      "\n",
      "26 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8462    0.9565    0.8980        46\n",
      "           1     1.0000    0.0909    0.1667        11\n",
      "           2     0.6774    0.7778    0.7241        27\n",
      "\n",
      "    accuracy                         0.7857        84\n",
      "   macro avg     0.8412    0.6084    0.5963        84\n",
      "weighted avg     0.8121    0.7857    0.7463        84\n",
      "\n",
      "27 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8462    0.9565    0.8980        46\n",
      "           1     1.0000    0.0909    0.1667        11\n",
      "           2     0.6774    0.7778    0.7241        27\n",
      "\n",
      "    accuracy                         0.7857        84\n",
      "   macro avg     0.8412    0.6084    0.5963        84\n",
      "weighted avg     0.8121    0.7857    0.7463        84\n",
      "\n",
      "28 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8302    0.9565    0.8889        46\n",
      "           1     1.0000    0.0909    0.1667        11\n",
      "           2     0.7000    0.7778    0.7368        27\n",
      "\n",
      "    accuracy                         0.7857        84\n",
      "   macro avg     0.8434    0.6084    0.5975        84\n",
      "weighted avg     0.8106    0.7857    0.7454        84\n",
      "\n",
      "29 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8182    0.9783    0.8911        46\n",
      "           1     1.0000    0.0909    0.1667        11\n",
      "           2     0.7143    0.7407    0.7273        27\n",
      "\n",
      "    accuracy                         0.7857        84\n",
      "   macro avg     0.8442    0.6033    0.5950        84\n",
      "weighted avg     0.8086    0.7857    0.7436        84\n",
      "\n",
      "30 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8182    0.9783    0.8911        46\n",
      "           1     0.0000    0.0000    0.0000        11\n",
      "           2     0.6897    0.7407    0.7143        27\n",
      "\n",
      "    accuracy                         0.7738        84\n",
      "   macro avg     0.5026    0.5730    0.5351        84\n",
      "weighted avg     0.6697    0.7738    0.7176        84\n",
      "\n",
      "31 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8182    0.9783    0.8911        46\n",
      "           1     0.0000    0.0000    0.0000        11\n",
      "           2     0.6897    0.7407    0.7143        27\n",
      "\n",
      "    accuracy                         0.7738        84\n",
      "   macro avg     0.5026    0.5730    0.5351        84\n",
      "weighted avg     0.6697    0.7738    0.7176        84\n",
      "\n",
      "32 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8070    1.0000    0.8932        46\n",
      "           1     0.0000    0.0000    0.0000        11\n",
      "           2     0.7407    0.7407    0.7407        27\n",
      "\n",
      "    accuracy                         0.7857        84\n",
      "   macro avg     0.5159    0.5802    0.5446        84\n",
      "weighted avg     0.6800    0.7857    0.7272        84\n",
      "\n",
      "33 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8070    1.0000    0.8932        46\n",
      "           1     0.0000    0.0000    0.0000        11\n",
      "           2     0.7407    0.7407    0.7407        27\n",
      "\n",
      "    accuracy                         0.7857        84\n",
      "   macro avg     0.5159    0.5802    0.5446        84\n",
      "weighted avg     0.6800    0.7857    0.7272        84\n",
      "\n",
      "34 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.8070    1.0000    0.8932        46\n",
      "           1     0.0000    0.0000    0.0000        11\n",
      "           2     0.7407    0.7407    0.7407        27\n",
      "\n",
      "    accuracy                         0.7857        84\n",
      "   macro avg     0.5159    0.5802    0.5446        84\n",
      "weighted avg     0.6800    0.7857    0.7272        84\n",
      "\n",
      "35 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.7931    1.0000    0.8846        46\n",
      "           1     0.0000    0.0000    0.0000        11\n",
      "           2     0.7692    0.7407    0.7547        27\n",
      "\n",
      "    accuracy                         0.7857        84\n",
      "   macro avg     0.5208    0.5802    0.5464        84\n",
      "weighted avg     0.6816    0.7857    0.7270        84\n",
      "\n",
      "36 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.7931    1.0000    0.8846        46\n",
      "           1     0.0000    0.0000    0.0000        11\n",
      "           2     0.7692    0.7407    0.7547        27\n",
      "\n",
      "    accuracy                         0.7857        84\n",
      "   macro avg     0.5208    0.5802    0.5464        84\n",
      "weighted avg     0.6816    0.7857    0.7270        84\n",
      "\n",
      "37 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.7931    1.0000    0.8846        46\n",
      "           1     0.0000    0.0000    0.0000        11\n",
      "           2     0.7692    0.7407    0.7547        27\n",
      "\n",
      "    accuracy                         0.7857        84\n",
      "   macro avg     0.5208    0.5802    0.5464        84\n",
      "weighted avg     0.6816    0.7857    0.7270        84\n",
      "\n",
      "38 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.7931    1.0000    0.8846        46\n",
      "           1     0.0000    0.0000    0.0000        11\n",
      "           2     0.7692    0.7407    0.7547        27\n",
      "\n",
      "    accuracy                         0.7857        84\n",
      "   macro avg     0.5208    0.5802    0.5464        84\n",
      "weighted avg     0.6816    0.7857    0.7270        84\n",
      "\n",
      "39 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.7931    1.0000    0.8846        46\n",
      "           1     0.0000    0.0000    0.0000        11\n",
      "           2     0.7692    0.7407    0.7547        27\n",
      "\n",
      "    accuracy                         0.7857        84\n",
      "   macro avg     0.5208    0.5802    0.5464        84\n",
      "weighted avg     0.6816    0.7857    0.7270        84\n",
      "\n",
      "40 stanford report:\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.7931    1.0000    0.8846        46\n",
      "           1     0.0000    0.0000    0.0000        11\n",
      "           2     0.7692    0.7407    0.7547        27\n",
      "\n",
      "    accuracy                         0.7857        84\n",
      "   macro avg     0.5208    0.5802    0.5464        84\n",
      "weighted avg     0.6816    0.7857    0.7270        84\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "/home/zsxm/miniconda3/envs/pytorch/lib/python3.7/site-packages/sklearn/metrics/_classification.py:1318: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n"
     ]
    }
   ],
   "source": [
    "# Test Stanford dissection-type classification results.\n",
    "# Window-level predictions come from a binary (negative vs. dissection) 3D net;\n",
    "# a patient is typed from the longest run of consecutive positive windows,\n",
    "# sweeping the minimum run length N from 1 to N_MAX.\n",
    "weight_list = [\n",
    "    '.details151/CE3D/09-12_174850_3D,有Sobel,只分类阴性和夹层/Net_best.pth',\n",
    "#     '.details151/CE3D/09-12_180329_3D,有Sobel,只分类阴性和夹层/Net_best.pth',\n",
    "#     '.details151/CE3D/09-12_180348_3D,有Sobel,只分类阴性和夹层/Net_best.pth',\n",
    "#     '.details151/CE3D/12-05_183652_3D,有Sobel,只分类阴性和夹层/Net_best.pth',\n",
    "#     '.details151/CE3D/12-05_185011_3D,有Sobel,只分类阴性和夹层/Net_best.pth',\n",
    "]\n",
    "N_MAX = 40  # maximum run-length threshold to sweep\n",
    "device = torch.device('cuda:2')\n",
    "exp_auc_list, exp_acc_list, exp_sen_list, exp_spe_list = [], [], [], []\n",
    "\n",
    "def _longest_run_label(preds, N):\n",
    "    \"\"\"Return the label (1 or 2) of the longest run of identical positive\n",
    "    predictions in `preds` with length >= N, or 0 if no such run exists.\"\"\"\n",
    "    pred_label = 0\n",
    "    max_len = -1\n",
    "    pre = -1\n",
    "    start_idx = -1\n",
    "    for i, cur in enumerate(preds):\n",
    "        if cur != pre:\n",
    "            if pre in (1, 2):\n",
    "                ln = i - start_idx\n",
    "                if ln >= N and ln > max_len:\n",
    "                    max_len = ln\n",
    "                    pred_label = pre\n",
    "            start_idx = i\n",
    "        pre = cur\n",
    "    # FIX: the original scan only closed a run when the label changed, so a\n",
    "    # positive run extending to the end of the sequence was never counted.\n",
    "    if pre in (1, 2):\n",
    "        ln = len(preds) - start_idx\n",
    "        if ln >= N and ln > max_len:\n",
    "            pred_label = pre\n",
    "    return pred_label\n",
    "\n",
    "for iw, weight in enumerate(weight_list):\n",
    "    net = resnet3d(34, n_channels=2, n_classes=2, conv1_t_size=3)\n",
    "    net.load_state_dict(torch.load(weight, map_location=device))\n",
    "    net.to(device)\n",
    "    net.eval()\n",
    "\n",
    "    true_list = []\n",
    "    sjpred_list = []\n",
    "    with torch.no_grad():\n",
    "        for jimgs, simgs, label, patient in tqdm(dataset, desc=f'{iw+1}/{len(weight_list)}', leave=True, ncols=50):\n",
    "            true_list.append(stanford[patient])\n",
    "            # Sliding windows of 7 consecutive slices, stacked on dim 1;\n",
    "            # both series are classified in a single batched forward pass.\n",
    "            simgs = [torch.stack(simgs[i:i+7], dim=1) for i in range(len(simgs)-6)]\n",
    "            jimgs = [torch.stack(jimgs[i:i+7], dim=1) for i in range(len(jimgs)-6)]\n",
    "            imgs = torch.stack(simgs+jimgs, dim=0).to(device)\n",
    "            preds = torch.argmax(net(imgs), dim=1).tolist()\n",
    "            spreds, jpreds = preds[:len(simgs)], preds[len(simgs):]\n",
    "            sjpred_list.append((spreds, jpreds))\n",
    "    true_list = np.array(true_list)\n",
    "    for N in range(1, N_MAX+1):\n",
    "        pred_list = []\n",
    "        for spreds, jpreds in sjpred_list:\n",
    "            patient_label = 0\n",
    "            for isj, preds in enumerate([spreds, jpreds]):\n",
    "                if _longest_run_label(preds, N) != 0:\n",
    "                    if isj == 0:\n",
    "                        patient_label = 1  # first (s) series positive -> type 1\n",
    "                        break\n",
    "                    patient_label = 2  # only second (j) series positive -> type 2\n",
    "            pred_list.append(patient_label)\n",
    "        pred_list = np.array(pred_list)\n",
    "#         true_list[true_list == 2] =1\n",
    "#         pred_list[pred_list == 2] =1\n",
    "        print(f'{N} stanford report:\\n'+metrics.classification_report(true_list, pred_list, digits=4))\n",
    "    del net"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "e0d8f6f1be8e0931392fbb561bd7947073d83a45626316b25b70c173951275ba"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
